--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertTensorIteratorToLSTMSequence;
+class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
+class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Finds all TensorIterator layers, detects the pattern Squeeze->LSTMCell->Unsqueeze in the TensorIterator body,
+ * converts this pattern to LSTMSequence layer and replaces them TensorIterator.
+ */
+
+class ngraph::pass::ConvertTensorIteratorToLSTMSequence: public ngraph::pass::MatcherPass {
+public:
+    ConvertTensorIteratorToLSTMSequence();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Finds all TensorIterator layers, detects the pattern Squeeze->RNNCell->Unsqueeze in the TensorIterator body,
+ * converts this pattern to RNNSequence layer and replaces them TensorIterator.
+ */
+
+class ngraph::pass::ConvertTensorIteratorToRNNSequence: public ngraph::pass::MatcherPass {
+public:
+    ConvertTensorIteratorToRNNSequence();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Finds all TensorIterator layers, detects the pattern Squeeze->GRUCell->Unsqueeze in the TensorIterator body,
+ * converts this pattern to GRUSequence layer and replaces them TensorIterator.
+ */
+
+class ngraph::pass::ConvertTensorIteratorToGRUSequence: public ngraph::pass::MatcherPass {
+public:
+    ConvertTensorIteratorToGRUSequence();
+};
\ No newline at end of file
return false;
}
- const auto& W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- lstm_sequence->input_value(4).get_node_shared_ptr());
- if (!W) {
- return false;
- }
+ const auto& W = lstm_sequence->input_value(4);
+ const auto& R = lstm_sequence->input_value(5);
- const auto& R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- lstm_sequence->input_value(5).get_node_shared_ptr());
- if (!R) {
+ // Bidirectional cases are not supported
+ if (lstm_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
return false;
- }
// for forward/reverse cases we can squeeze num_direction dimension
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(1).get_source_output(), axis_1);
- auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(2).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(1), axis_1);
+ auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(2), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(6).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input_value(6), axis_2);
auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
lstm_sequence->input(0).get_source_output(), // X
in_1, // initial_hidden_state
in_2, // initial_cell_state
- lstm_sequence->input(3).get_source_output(),
+ lstm_sequence->input_value(3),
in_3, // WR
in_4, // B
lstm_sequence->get_hidden_size(),
return false;
}
- auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- gru_sequence->input_value(3).get_node_shared_ptr());
- if (!W) {
- return false;
- }
+ auto W = gru_sequence->input_value(3);
+ auto R = gru_sequence->input_value(4);
- auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- gru_sequence->input_value(4).get_node_shared_ptr());
- if (!R) {
- return false;
- }
-
- // todo: add exception?
+ // Bidirectional cases are not supported
if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
return false;
// for forward/reverse cases we can squeeze num_direction dimension
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(1).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(1), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(5).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input_value(5), axis_2);
auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
- gru_sequence->input(0).get_source_output(), // X
+ gru_sequence->input_value(0), // X
in_1, // initial_hidden_state
- gru_sequence->input(2).get_source_output(),
+ gru_sequence->input_value(2),
in_3, // WR
in_4, // B
gru_sequence->get_hidden_size(),
return false;
}
- auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- rnn_sequence->input_value(3).get_node_shared_ptr());
- if (!W) {
+ // Bidirectional cases are not supported
+ if (rnn_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
return false;
- }
- auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
- rnn_sequence->input_value(4).get_node_shared_ptr());
- if (!R) {
- return false;
- }
+ auto W = rnn_sequence->input_value(3);
+ auto R = rnn_sequence->input_value(4);
// for forward/reverse cases we can squeeze num_direction dimension
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
- auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(1).get_source_output(), axis_1);
- auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
+ auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(1), axis_1);
+ auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::OutputVector{W, R}, 2);
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
- auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(5).get_source_output(), axis_2);
+ auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input_value(5), axis_2);
auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
- rnn_sequence->input(0).get_source_output(), // X
+ rnn_sequence->input_value(0), // X
in_1, // initial_hidden_state
rnn_sequence->input_value(2),
in_3, // WR
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/tensor_iterator_transformations/convert_ti_to_sequences.h"
+#include "transformations/utils/utils.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/node.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/graph_util.hpp>
+#include <ngraph/specialize_function.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+
+// Replaces a TensorIterator whose body is exactly Squeeze->LSTMCell->Unsqueeze
+// with a single opset5 LSTMSequence plus Transpose/Squeeze/Unsqueeze glue nodes.
+// The callback bails out (returns false) on any body/port layout it cannot prove safe.
+ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
+    auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
+            ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
+        auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
+        if (!ti || !m_transformation_callback(ti))
+            return false;
+
+        // create pattern
+        // NOTE: the concrete shapes below are placeholders; the pattern matcher
+        // compares node types/topology, not these dimensions.
+        auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
+        auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
+
+        auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
+        auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_C_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
+        auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
+        auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{4});
+
+        auto cell = std::make_shared<ngraph::opset4::LSTMCell>(input_data, input_H_state, input_C_state,
+                input_W, input_R, input_B, 1);
+
+        auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 1);
+        auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
+        ngraph::pattern::Matcher matcher(unsqueeze);
+
+        // Try the pattern against every body Result; one hit is enough.
+        bool match = false;
+        auto func = ti->get_body();
+        for (const auto& res : func->get_results()) {
+            match = matcher.match((res->get_input_source_output(0)));
+            if (match)
+                break;
+        }
+
+        // All nodes are in the TI body should be matched in pattern
+        if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
+            return false;
+
+        auto pattern_map = matcher.get_pattern_map();
+
+        // Classify TI inputs: slot 0 = sliced data input, slot 1 = H state, slot 2 = C state.
+        auto params = func->get_parameters();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
+        int64_t stride = 0, slice_axis = 0;
+        size_t batch_size = 0;
+        for (const auto& input_desc : ti->get_input_descriptions()) {
+            auto param = params[input_desc->m_body_parameter_index];
+            if (param == pattern_map[data]) {
+                // to get batch size value
+                if (param->get_partial_shape().is_dynamic()) {
+                    return false;
+                }
+                auto slice_input
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
+                if (!slice_input)
+                    return false;
+
+                stride = slice_input->m_stride;
+                slice_axis = slice_input->m_axis;
+
+                // Only sequence axis 0 or 1 is handled below.
+                if (!(slice_axis == 0 || slice_axis == 1)) {
+                    return false;
+                }
+                batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
+                ordered_in_descs[0] = input_desc;
+            } else if (param == pattern_map[input_H_state]) {
+                ordered_in_descs[1] = input_desc;
+            } else if (param == pattern_map[input_C_state]) {
+                ordered_in_descs[2] = input_desc;
+            } else {
+                return false;
+            }
+        }
+
+        // Classify TI outputs: slot 0 = concatenated per-iteration output,
+        // slot 1 = final H state, slot 2 = final C state.
+        auto results = func->get_results();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(3);
+        for (const auto& output_desc : ti->get_output_descriptions()) {
+            std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
+            if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
+                auto concat_output
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
+                if (!concat_output)
+                    return false;
+
+                // The concat output's stride decides FORWARD/REVERSE below.
+                stride = concat_output->m_stride;
+                ordered_out_descs[0] = output_desc;
+            } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
+                ordered_out_descs[1] = output_desc;
+            } else if (res->get_input_source_output(0) == pattern_map[cell]->output(1)) {
+                ordered_out_descs[2] = output_desc;
+            } else {
+                return false;
+            }
+        }
+
+        // Every iteration processes one slice, so seq_len == number of TI iterations.
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
+        const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::opset4::LSTMCell>(pattern_map[cell]);
+        auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
+        if (slice_axis == 0) {
+            // Sequence was sliced over axis 0: bring data to [batch, seq, ...] layout first.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
+        }
+        // Add the num_directions dimension expected by LSTMSequence states/weights.
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
+        auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[2]->m_input_index], axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
+        // Negative stride means the TI iterated over the sequence in reverse.
+        auto sequence = std::make_shared<op::v5::LSTMSequence>(
+                in_0,
+                in_1,
+                in_2,
+                seq_lengths,
+                in_4,
+                in_5,
+                in_6,
+                lstm_cell->get_hidden_size(),
+                stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
+                lstm_cell->get_activations_alpha(),
+                lstm_cell->get_activations_beta(),
+                lstm_cell->get_activations(),
+                lstm_cell->get_clip());
+
+        // Drop the num_directions dimension again to match the TI output shapes.
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
+        auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
+        auto out_2 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(2), axis_out);
+
+        std::shared_ptr<Node> out = out_0;
+        if (slice_axis == 0) {
+            // Undo the layout change applied to the data input.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
+        }
+
+        // Rewire all consumers of the TI outputs to the new sub-graph,
+        // preserving the original IE output tensor names.
+        ngraph::NodeVector outputs = {out, out_1, out_2};
+        for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
+            if (ordered_out_descs[i]) {
+                for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
+                    input.replace_source_output(outputs[i]->output(0));
+                }
+                outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
+            }
+        }
+
+        // Propagate runtime info from the replaced TI onto every node we created.
+        ngraph::NodeVector new_nodes = {in_1, in_2, in_4, in_5, in_6, sequence, out_0, out_1, out_2};
+        if (slice_axis == 0) {
+            new_nodes.push_back(out);
+            new_nodes.push_back(in_0.get_node_shared_ptr());
+        }
+        copy_runtime_info(ti, new_nodes);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToLSTMSequence");
+    register_matcher(m, callback);
+}
+
+// Replaces a TensorIterator whose body is exactly Squeeze->RNNCell->Unsqueeze
+// with a single opset5 RNNSequence plus Transpose/Squeeze/Unsqueeze glue nodes.
+// Bails out (returns false) on any body/port layout it cannot prove safe.
+ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
+    auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
+            ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
+        auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
+        if (!ti || !m_transformation_callback(ti))
+            return false;
+
+        // create pattern
+        // NOTE: shapes are placeholders; matching is by node type/topology.
+        auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
+        auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
+        auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
+
+        auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{1});
+
+        auto cell = std::make_shared<ngraph::opset4::RNNCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
+
+        auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
+        auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
+        ngraph::pattern::Matcher matcher(unsqueeze);
+
+        // Try the pattern against every body Result; one hit is enough.
+        bool match = false;
+        auto func = ti->get_body();
+        for (const auto& res : func->get_results()) {
+            match = matcher.match((res->get_input_source_output(0)));
+            if (match)
+                break;
+        }
+
+        // All nodes are in the TI body should be matched in pattern
+        if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
+            return false;
+
+        auto pattern_map = matcher.get_pattern_map();
+
+        // Classify TI inputs: slot 0 = sliced data input, slot 1 = H state.
+        auto params = func->get_parameters();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
+        int64_t stride = 0, slice_axis = 0;
+        size_t batch_size = 0;
+        for (const auto& input_desc : ti->get_input_descriptions()) {
+            auto param = params[input_desc->m_body_parameter_index];
+            if (param == pattern_map[data]) {
+                // to get batch size value
+                if (param->get_partial_shape().is_dynamic()) {
+                    return false;
+                }
+                auto slice_input
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
+                if (!slice_input)
+                    return false;
+
+                stride = slice_input->m_stride;
+                slice_axis = slice_input->m_axis;
+                // Only sequence axis 0 or 1 is handled below.
+                if (!(slice_axis == 0 || slice_axis == 1)) {
+                    return false;
+                }
+                batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
+                ordered_in_descs[0] = input_desc;
+            } else if (param == pattern_map[input_H_state]) {
+                ordered_in_descs[1] = input_desc;
+            } else {
+                return false;
+            }
+        }
+
+        // One slice per iteration, so seq_len == number of TI iterations.
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
+
+        // Classify TI outputs: slot 0 = concatenated output, slot 1 = final H state.
+        auto results = func->get_results();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
+        for (const auto& output_desc : ti->get_output_descriptions()) {
+            std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
+            if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
+                auto concat_output
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
+                if (!concat_output)
+                    return false;
+
+                // The concat output's stride decides FORWARD/REVERSE below.
+                stride = concat_output->m_stride;
+                ordered_out_descs[0] = output_desc;
+            } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
+                ordered_out_descs[1] = output_desc;
+            } else {
+                return false;
+            }
+        }
+
+        const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::RNNCell>(pattern_map[cell]);
+
+        auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
+        if (slice_axis == 0) {
+            // Sequence was sliced over axis 0: bring data to [batch, seq, ...] layout first.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
+        }
+
+        // Add the num_directions dimension expected by RNNSequence state/weights.
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
+        // Negative stride means the TI iterated over the sequence in reverse.
+        auto sequence = std::make_shared<op::v5::RNNSequence>(
+                in_0,
+                in_1,
+                seq_lengths,
+                in_3,
+                in_4,
+                in_5,
+                rnn_cell->get_hidden_size(),
+                stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
+                rnn_cell->get_activations(),
+                rnn_cell->get_activations_alpha(),
+                rnn_cell->get_activations_beta(),
+                rnn_cell->get_clip());
+
+        // Drop the num_directions dimension again to match the TI output shapes.
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
+        auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
+
+        std::shared_ptr<Node> out = out_0;
+        if (slice_axis == 0) {
+            // Undo the layout change applied to the data input.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
+        }
+
+        // Rewire all consumers of the TI outputs to the new sub-graph,
+        // preserving the original IE output tensor names.
+        ngraph::NodeVector outputs = {out, out_1};
+        for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
+            if (ordered_out_descs[i]) {
+                for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
+                    input.replace_source_output(outputs[i]->output(0));
+                }
+                outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
+            }
+        }
+
+        // Propagate runtime info from the replaced TI onto every node we created.
+        ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
+        if (slice_axis == 0) {
+            new_nodes.push_back(out);
+            new_nodes.push_back(in_0);
+        }
+        copy_runtime_info(ti, as_node_vector(new_nodes));
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToRNNSequence");
+    register_matcher(m, callback);
+}
+
+// Replaces a TensorIterator whose body is exactly Squeeze->GRUCell->Unsqueeze
+// with a single opset5 GRUSequence plus Transpose/Squeeze/Unsqueeze glue nodes.
+// Bails out (returns false) on any body/port layout it cannot prove safe.
+ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
+    auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
+            ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset4::TensorIterator>());
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
+        auto ti = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator>(m.get_match_root());
+        if (!ti || !m_transformation_callback(ti))
+            return false;
+
+        // create pattern
+        // NOTE: shapes are placeholders; matching is by node type/topology.
+        auto data = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
+        auto axis_squeeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
+        auto input_data = std::make_shared<ngraph::opset4::Squeeze>(data, axis_squeeze);
+
+        auto input_H_state = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
+        auto input_W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
+        auto input_R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
+        auto input_B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{3});
+
+        auto cell = std::make_shared<ngraph::opset4::GRUCell>(input_data, input_H_state, input_W, input_R, input_B, 1);
+
+        auto axis_unsqueeze = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{1}, 0);
+        auto unsqueeze = std::make_shared<ngraph::opset4::Unsqueeze>(cell, axis_unsqueeze);
+        ngraph::pattern::Matcher matcher(unsqueeze);
+
+        // Try the pattern against every body Result; one hit is enough.
+        bool match = false;
+        auto func = ti->get_body();
+        for (const auto& res : func->get_results()) {
+            match = matcher.match((res->get_input_source_output(0)));
+            if (match)
+                break;
+        }
+
+        // All nodes are in the TI body should be matched in pattern
+        if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
+            return false;
+
+        auto pattern_map = matcher.get_pattern_map();
+
+        // Classify TI inputs: slot 0 = sliced data input, slot 1 = H state.
+        auto params = func->get_parameters();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::InputDescription>> ordered_in_descs(3);
+        int64_t stride = 0, slice_axis = 0;
+        size_t batch_size = 0;
+        for (const auto& input_desc : ti->get_input_descriptions()) {
+            auto param = params[input_desc->m_body_parameter_index];
+            if (param == pattern_map[data]) {
+                // to get batch size value
+                if (param->get_partial_shape().is_dynamic()) {
+                    return false;
+                }
+                auto slice_input
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::SliceInputDescription>(input_desc);
+                if (!slice_input)
+                    return false;
+
+                stride = slice_input->m_stride;
+                slice_axis = slice_input->m_axis;
+                // Only sequence axis 0 or 1 is handled below.
+                if (!(slice_axis == 0 || slice_axis == 1)) {
+                    return false;
+                }
+                batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
+                ordered_in_descs[0] = input_desc;
+            } else if (param == pattern_map[input_H_state]) {
+                ordered_in_descs[1] = input_desc;
+            } else {
+                return false;
+            }
+        }
+
+        // One slice per iteration, so seq_len == number of TI iterations.
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
+
+        // Classify TI outputs: slot 0 = concatenated output, slot 1 = final H state.
+        auto results = func->get_results();
+        std::vector<std::shared_ptr<ngraph::opset4::TensorIterator::OutputDescription>> ordered_out_descs(2);
+        for (const auto& output_desc : ti->get_output_descriptions()) {
+            std::shared_ptr<opset4::Result> res = results[output_desc->m_body_value_index];
+            if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
+                auto concat_output
+                        = std::dynamic_pointer_cast<ngraph::opset4::TensorIterator::ConcatOutputDescription>(output_desc);
+                if (!concat_output)
+                    return false;
+
+                // The concat output's stride decides FORWARD/REVERSE below.
+                stride = concat_output->m_stride;
+                ordered_out_descs[0] = output_desc;
+            } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
+                ordered_out_descs[1] = output_desc;
+            } else {
+                return false;
+            }
+        }
+
+        // NOTE(review): variable is named rnn_cell but holds the matched GRUCell.
+        const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::GRUCell>(pattern_map[cell]);
+
+        auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
+        if (slice_axis == 0) {
+            // Sequence was sliced over axis 0: bring data to [batch, seq, ...] layout first.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            in_0 = std::make_shared<ngraph::opset4::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
+        }
+
+        // Add the num_directions dimension expected by GRUSequence state/weights.
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
+        // Negative stride means the TI iterated over the sequence in reverse.
+        auto sequence = std::make_shared<op::v5::GRUSequence>(
+                in_0,
+                in_1,
+                seq_lengths,
+                in_3,
+                in_4,
+                in_5,
+                rnn_cell->get_hidden_size(),
+                stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
+                rnn_cell->get_activations(),
+                rnn_cell->get_activations_alpha(),
+                rnn_cell->get_activations_beta(),
+                rnn_cell->get_clip(),
+                rnn_cell->get_linear_before_reset());
+
+        // Drop the num_directions dimension again to match the TI output shapes.
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(0), axis_out);
+        auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(sequence->output(1), axis_out);
+
+        std::shared_ptr<Node> out = out_0;
+        if (slice_axis == 0) {
+            // Undo the layout change applied to the data input.
+            auto order = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
+            out = std::make_shared<ngraph::opset4::Transpose>(out_0, order);
+        }
+
+        // Rewire all consumers of the TI outputs to the new sub-graph,
+        // preserving the original IE output tensor names.
+        ngraph::NodeVector outputs = {out, out_1};
+        for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
+            if (ordered_out_descs[i]) {
+                for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
+                    input.replace_source_output(outputs[i]->output(0));
+                }
+                outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
+            }
+        }
+
+        // Propagate runtime info from the replaced TI onto every node we created.
+        ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
+        if (slice_axis == 0) {
+            new_nodes.push_back(out);
+            new_nodes.push_back(in_0);
+        }
+        copy_runtime_info(ti, as_node_vector(new_nodes));
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToGRUSequence");
+    register_matcher(m, callback);
+}
\ No newline at end of file
--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/test_common.hpp"
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/pass/manager.hpp>
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph_ops/fully_connected.hpp>
+#include <transformations/tensor_iterator_transformations/convert_ti_to_sequences.h>
+#include <transformations/utils/utils.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+// Builds a TI with a Squeeze->LSTMCell->Unsqueeze body, runs the
+// ConvertTensorIteratorToLSTMSequence pass, and compares the result against a
+// hand-built reference graph containing an opset5 LSTMSequence.
+TEST(TransformationTests, ConvertTensorIteratorToLSTMSequence) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+        auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
+        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+        auto Zi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        // Body
+        auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
+
+        // All-zero weights: values are irrelevant, only graph structure is compared.
+        auto w_val = std::vector<float>(512 * 16, 0);
+        auto r_val = std::vector<float>(512 * 128, 0);
+        auto b_val = std::vector<float>(512, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
+
+        auto lstm_cell = std::make_shared<opset4::LSTMCell>(squeeze, Yi, Zi, W, R, B, 128);
+        auto res_1 = std::make_shared<opset4::Result>(lstm_cell);
+        auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(lstm_cell, axis_unsqueeze);
+        auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
+                ParameterVector{Xi, Yi, Zi});
+
+        auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
+        tensor_iterator->set_body(body);
+
+        tensor_iterator->set_invariant_input(Zi, Z);
+        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
+        tensor_iterator->set_merged_input(Yi, Y, res_1);
+
+        auto out0 = tensor_iterator->get_iter_value(res_1, -1);
+        auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
+
+        // Only the concatenated output is consumed by the function result.
+        auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
+                ngraph::ParameterVector{X, Y, Z});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertTensorIteratorToLSTMSequence>();
+        // Force the transformation callback to accept every TI node.
+        m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        // Reference graph: the sub-graph the pass is expected to produce.
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+        auto Z = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        auto w_val = std::vector<float>(512 * 16, 0);
+        auto r_val = std::vector<float>(512 * 128, 0);
+        auto b_val = std::vector<float>(512, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val);
+
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
+        auto in_2 = std::make_shared<ngraph::opset4::Unsqueeze>(Z, axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
+        auto in_6 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
+
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
+        auto lstm_seq = std::make_shared<op::v5::LSTMSequence>(X, in_1, in_2, seq_lengths, in_4, in_5, in_6,
+                128, ngraph::op::RecurrentSequenceDirection::FORWARD);
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(0), axis_out);
+        auto out_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(1), axis_out);
+        // Fix: final cell state comes from output(2); output(1) is the hidden state
+        // and was a copy-paste mistake (duplicate of out_1).
+        auto out_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_seq->output(2), axis_out);
+        auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y, Z});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+// Checks that a TensorIterator whose body is Squeeze->RNNCell->Unsqueeze is
+// converted into a single opset5 RNNSequence operation.
+TEST(TransformationTests, ConvertTensorIteratorToRNNSequence) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        // Outer inputs: X - input sequence, Y - initial hidden state.
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        // Body parameters: Xi - current slice of X, Yi - carried hidden state.
+        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
+        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        // Body: drop the sequence dimension before feeding the cell.
+        auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
+
+        // Zero-filled cell weights: W{hidden_size, input_size},
+        // R{hidden_size, hidden_size}, B{hidden_size}; hidden_size = 128.
+        auto w_val = std::vector<float>(128 * 16, 0);
+        auto r_val = std::vector<float>(128 * 128, 0);
+        auto b_val = std::vector<float>(128, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val);
+
+        auto rnn_cell = std::make_shared<opset4::RNNCell>(squeeze, Yi, W, R, B, 128);
+        auto res_1 = std::make_shared<opset4::Result>(rnn_cell);
+        // Restore the sequence dimension on the cell output.
+        auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(rnn_cell, axis_unsqueeze);
+        auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
+                                               ParameterVector{Xi, Yi});
+
+        auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
+        tensor_iterator->set_body(body);
+
+        // X is sliced along the sequence axis; the hidden state is carried
+        // between iterations via the back-edge Yi <- res_1.
+        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
+        tensor_iterator->set_merged_input(Yi, Y, res_1);
+
+        // Register TI outputs (return values intentionally discarded):
+        // output 0 - last hidden state, output 1 - concatenated body outputs.
+        tensor_iterator->get_iter_value(res_1, -1);
+        tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
+
+        // Only the concatenated output participates in the function.
+        auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
+                                               ngraph::ParameterVector{X, Y});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertTensorIteratorToRNNSequence>();
+        m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        // Reference graph: one RNNSequence; the hidden state and the weights
+        // are unsqueezed to add the num_directions dimension expected by it.
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        auto w_val = std::vector<float>(128 * 16, 0);
+        auto r_val = std::vector<float>(128 * 128, 0);
+        auto b_val = std::vector<float>(128, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val);
+
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
+
+        // seq_lengths equals the sequence dimension of X.
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
+        auto rnn_sequence = std::make_shared<op::v5::RNNSequence>(X, in_1, seq_lengths, in_3, in_4, in_5,
+                                                                  128, ngraph::op::RecurrentSequenceDirection::FORWARD);
+        // Squeeze the num_directions dimension back out of the Y output.
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->output(0), axis_out);
+        auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+// Checks that a TensorIterator whose body is Squeeze->GRUCell->Unsqueeze is
+// converted into a single opset5 GRUSequence operation.
+TEST(TransformationTests, ConvertTensorIteratorToGRUSequence) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        // Outer inputs: X - input sequence, Y - initial hidden state.
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        // Body parameters: Xi - current slice of X, Yi - carried hidden state.
+        auto Xi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 1, 16});
+        auto Yi = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        // Body: drop the sequence dimension before feeding the cell.
+        auto axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto squeeze = std::make_shared<opset4::Squeeze>(Xi, axis);
+
+        // Zero-filled cell weights: GRU packs 3 gates, so the first dimension
+        // is 3 * hidden_size = 384 (hidden_size = 128).
+        auto w_val = std::vector<float>(384 * 16, 0);
+        auto r_val = std::vector<float>(384 * 128, 0);
+        auto b_val = std::vector<float>(384, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);
+
+        auto gru_cell = std::make_shared<opset4::GRUCell>(squeeze, Yi, W, R, B, 128);
+        auto res_1 = std::make_shared<opset4::Result>(gru_cell);
+        // Restore the sequence dimension on the cell output.
+        auto axis_unsqueeze = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(gru_cell, axis_unsqueeze);
+        auto res_2 = std::make_shared<opset4::Result>(unsqueeze);
+        auto body = std::make_shared<Function>(OutputVector{res_1, res_2},
+                                               ParameterVector{Xi, Yi});
+
+        auto tensor_iterator = std::make_shared<opset4::TensorIterator>();
+        tensor_iterator->set_body(body);
+
+        // X is sliced along the sequence axis; the hidden state is carried
+        // between iterations via the back-edge Yi <- res_1.
+        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
+        tensor_iterator->set_merged_input(Yi, Y, res_1);
+
+        // Register TI outputs (return values intentionally discarded):
+        // output 0 - last hidden state, output 1 - concatenated body outputs.
+        tensor_iterator->get_iter_value(res_1, -1);
+        tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 1);
+
+        // Only the concatenated output participates in the function.
+        auto res_ti_1 = std::make_shared<opset4::Result>(tensor_iterator->output(1));
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1},
+                                               ngraph::ParameterVector{X, Y});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::InitNodeInfo>();
+        m.register_pass<ngraph::pass::ConvertTensorIteratorToGRUSequence>();
+        m.set_callback([](const std::shared_ptr<const Node>&) -> bool { return true; });
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        // Reference graph: one GRUSequence; the hidden state and the weights
+        // are unsqueezed to add the num_directions dimension expected by it.
+        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 2, 16});
+        auto Y = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 128});
+
+        auto w_val = std::vector<float>(384 * 16, 0);
+        auto r_val = std::vector<float>(384 * 128, 0);
+        auto b_val = std::vector<float>(384, 0);
+        auto W = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val);
+        auto R = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val);
+        auto B = ngraph::opset4::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val);
+
+        auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto in_1 = std::make_shared<ngraph::opset4::Unsqueeze>(Y, axis_1);
+
+        auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
+        auto in_3 = std::make_shared<ngraph::opset4::Unsqueeze>(W, axis_2);
+        auto in_4 = std::make_shared<ngraph::opset4::Unsqueeze>(R, axis_2);
+        auto in_5 = std::make_shared<ngraph::opset4::Unsqueeze>(B, axis_2);
+
+        // seq_lengths equals the sequence dimension of X.
+        auto seq_lengths = ngraph::opset4::Constant::create(element::i32, Shape{1}, {2});
+        auto gru_sequence = std::make_shared<op::v5::GRUSequence>(X, in_1, seq_lengths, in_3, in_4, in_5,
+                                                                  128, ngraph::op::RecurrentSequenceDirection::FORWARD);
+        // Squeeze the num_directions dimension back out of the Y output.
+        auto axis_out = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
+        auto out_0 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->output(0), axis_out);
+        auto res_ti_1 = std::make_shared<opset4::Result>(out_0);
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+}
\ No newline at end of file