e439f278660f39dbc4af3d91b3fd23d880ca81d0
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / op_conversions / convert_ti_to_sequences.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/op_conversions/convert_ti_to_sequences.hpp"
6 #include "transformations/utils/utils.hpp"
7
8 #include <memory>
9 #include <vector>
10
11 #include <ngraph/node.hpp>
12 #include <ngraph/pass/manager.hpp>
13 #include <ngraph/opsets/opset5.hpp>
14 #include <ngraph/opsets/opset1.hpp>
15 #include <ngraph/rt_info.hpp>
16 #include <ngraph/graph_util.hpp>
17 #include <ngraph/specialize_function.hpp>
18 #include <ngraph/pattern/op/wrap_type.hpp>
19
// nGraph RTTI definitions for the three passes implemented in this file.
// Each invocation takes the pass class, its type name as a string, and a trailing 0
// (presumably the type version -- confirm against the NGRAPH_RTTI_DEFINITION macro).
20 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToLSTMSequence, "ConvertTensorIteratorToLSTMSequence", 0);
21 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToRNNSequence, "ConvertTensorIteratorToRNNSequence", 0);
22 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "ConvertTensorIteratorToGRUSequence", 0);
23
24 ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
25     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
26                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
27     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
28         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
29         if (!ti || m_transformation_callback(ti))
30             return false;
31
32         // create pattern
33         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
34         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
35         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
36         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
37         auto input_C_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
38         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
39         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
40         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4});
41
42         auto cell = std::make_shared<ngraph::opset5::LSTMCell>(squeeze, input_H_state, input_C_state,
43                                                                input_W, input_R, input_B, 1);
44         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
45         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
46         ngraph::pattern::Matcher matcher(unsqueeze);
47
48         bool match = false;
49         auto func = ti->get_body();
50         for (const auto& res : func->get_results()) {
51             match = matcher.match((res->get_input_source_output(0)));
52             if (match)
53                 break;
54         }
55
56         // support for opset1::LSTMCell
57         auto cell_v1 = std::make_shared<ngraph::opset1::LSTMCell>(squeeze, input_H_state, input_C_state,
58                                                                  input_W, input_R, input_B, 1);
59         if (!match) {
60             unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell_v1, pattern_2, false);
61             matcher.clear_state();
62             matcher.m_pattern_node = unsqueeze;
63             for (const auto& res : func->get_results()) {
64                 match = matcher.match((res->get_input_source_output(0)));
65                 if (match)
66                     break;
67             }
68         }
69
70         // All nodes are in the TI body should be matched in pattern
71         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
72             return false;
73
74         auto pattern_map = matcher.get_pattern_map();
75         std::shared_ptr<Node>& found_cell = pattern_map[cell];
76         if (!found_cell)
77             found_cell = pattern_map[cell_v1];
78
79         auto params = func->get_parameters();
80         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
81         int64_t stride = 0, slice_axis = 0;
82         size_t batch_size = 0;
83         for (const auto& input_desc : ti->get_input_descriptions()) {
84             auto param = params[input_desc->m_body_parameter_index];
85             if (param == pattern_map[data]) {
86                 // to get batch size value
87                 if (param->get_partial_shape().is_dynamic()) {
88                     return false;
89                 }
90                 auto slice_input
91                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
92                 if (!slice_input)
93                     return false;
94
95                 stride = slice_input->m_stride;
96                 slice_axis = slice_input->m_axis;
97
98                 if (!(slice_axis == 0 || slice_axis == 1)) {
99                     return false;
100                 }
101                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
102                 ordered_in_descs[0] = input_desc;
103             } else if (param == pattern_map[input_H_state]) {
104                 ordered_in_descs[1] = input_desc;
105             } else if (param == pattern_map[input_C_state]) {
106                 ordered_in_descs[2] = input_desc;
107             } else {
108                 return false;
109             }
110         }
111
112         auto results = func->get_results();
113         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(3);
114         for (const auto& output_desc : ti->get_output_descriptions()) {
115             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
116             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
117                 auto concat_output
118                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
119                 if (!concat_output)
120                     return false;
121
122                 stride = concat_output->m_stride;
123                 ordered_out_descs[0] = output_desc;
124             } else if (res->get_input_source_output(0) == found_cell->output(0)) {
125                 ordered_out_descs[1] = output_desc;
126             } else if (res->get_input_source_output(0) == found_cell->output(1)) {
127                 ordered_out_descs[2] = output_desc;
128             } else {
129                 return false;
130             }
131         }
132
133         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
134         const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(found_cell);
135         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
136         if (slice_axis == 0) {
137             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
138             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
139         }
140         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
141         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
142         auto in_2 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[2]->m_input_index], axis_1);
143
144         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
145         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
146         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
147         auto in_6 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
148         auto sequence = std::make_shared<opset5::LSTMSequence>(
149                 in_0,
150                 in_1,
151                 in_2,
152                 seq_lengths,
153                 in_4,
154                 in_5,
155                 in_6,
156                 lstm_cell->get_hidden_size(),
157                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
158                 lstm_cell->get_activations_alpha(),
159                 lstm_cell->get_activations_beta(),
160                 lstm_cell->get_activations(),
161                 lstm_cell->get_clip());
162
163         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
164         Output<Node> out = sequence->output(0);
165         if (slice_axis == 0) {
166             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
167             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
168         }
169         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
170         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
171         auto out_2 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(2), axis_out);
172
173         ngraph::NodeVector outputs = {out_0, out_1, out_2};
174         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
175             if (ordered_out_descs[i]) {
176                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
177                     input.replace_source_output(outputs[i]->output(0));
178                 }
179                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
180             }
181         }
182
183         ngraph::OutputVector new_nodes = {in_1, in_2, in_4, in_5, in_6, sequence->output(0), out_0, out_1, out_2};
184         if (slice_axis == 0) {
185             new_nodes.push_back(out);
186             new_nodes.push_back(in_0.get_node_shared_ptr());
187         }
188         copy_runtime_info(ti, as_node_vector(new_nodes));
189         return true;
190     };
191
192     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToLSTMSequence");
193     register_matcher(m, callback);
194 }
195
196 ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
197     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
198                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
199     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
200         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
201         if (!ti || m_transformation_callback(ti))
202             return false;
203
204         // create pattern
205         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
206         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
207         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
208
209         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
210         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
211         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
212         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1});
213
214         auto cell = std::make_shared<ngraph::opset5::RNNCell>(squeeze, input_H_state, input_W, input_R, input_B, 1);
215
216         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
217         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
218         ngraph::pattern::Matcher matcher(unsqueeze);
219
220         bool match = false;
221         auto func = ti->get_body();
222         for (const auto& res : func->get_results()) {
223             match = matcher.match((res->get_input_source_output(0)));
224             if (match)
225                 break;
226         }
227
228         // All nodes are in the TI body should be matched in pattern
229         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
230             return false;
231
232         auto pattern_map = matcher.get_pattern_map();
233
234         auto params = func->get_parameters();
235         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
236         int64_t stride = 0, slice_axis = 0;
237         size_t batch_size = 0;
238         for (const auto& input_desc : ti->get_input_descriptions()) {
239             auto param = params[input_desc->m_body_parameter_index];
240             if (param == pattern_map[data]) {
241                 // to get batch size value
242                 if (param->get_partial_shape().is_dynamic()) {
243                     return false;
244                 }
245                 auto slice_input
246                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
247                 if (!slice_input)
248                     return false;
249
250                 stride = slice_input->m_stride;
251                 slice_axis = slice_input->m_axis;
252                 if (!(slice_axis == 0 || slice_axis == 1)) {
253                     return false;
254                 }
255                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
256                 ordered_in_descs[0] = input_desc;
257             } else if (param == pattern_map[input_H_state]) {
258                 ordered_in_descs[1] = input_desc;
259             } else {
260                 return false;
261             }
262         }
263
264         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
265
266         auto results = func->get_results();
267         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(2);
268         for (const auto& output_desc : ti->get_output_descriptions()) {
269             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
270             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
271                 auto concat_output
272                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
273                 if (!concat_output)
274                     return false;
275
276                 stride = concat_output->m_stride;
277                 ordered_out_descs[0] = output_desc;
278             } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
279                 ordered_out_descs[1] = output_desc;
280             } else {
281                 return false;
282             }
283         }
284
285         const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::RNNCell>(pattern_map[cell]);
286
287         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
288         if (slice_axis == 0) {
289             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
290             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
291         }
292
293         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
294         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
295
296         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
297         auto in_3 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
298         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
299         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
300         auto sequence = std::make_shared<opset5::RNNSequence>(
301                 in_0,
302                 in_1,
303                 seq_lengths,
304                 in_3,
305                 in_4,
306                 in_5,
307                 rnn_cell->get_hidden_size(),
308                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
309                 rnn_cell->get_activations(),
310                 rnn_cell->get_activations_alpha(),
311                 rnn_cell->get_activations_beta(),
312                 rnn_cell->get_clip());
313
314         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
315         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
316
317         Output<Node> out = sequence->output(0);
318         if (slice_axis == 0) {
319             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
320             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
321         }
322         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
323
324         ngraph::NodeVector outputs = {out_0, out_1};
325         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
326             if (ordered_out_descs[i]) {
327                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
328                     input.replace_source_output(outputs[i]->output(0));
329                 }
330                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
331             }
332         }
333
334         ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
335         if (slice_axis == 0) {
336             new_nodes.push_back(out);
337             new_nodes.push_back(in_0);
338         }
339         copy_runtime_info(ti, as_node_vector(new_nodes));
340         return true;
341     };
342
343     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToRNNSequence");
344     register_matcher(m, callback);
345 }
346
347 ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
348     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
349                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
350     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
351         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
352         if (!ti || m_transformation_callback(ti))
353             return false;
354
355         // create pattern
356         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
357         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
358         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
359
360         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
361         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
362         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
363         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3});
364
365         auto cell = std::make_shared<ngraph::opset5::GRUCell>(squeeze, input_H_state, input_W, input_R, input_B, 1);
366
367         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
368         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
369         ngraph::pattern::Matcher matcher(unsqueeze);
370
371         bool match = false;
372         auto func = ti->get_body();
373         for (const auto& res : func->get_results()) {
374             match = matcher.match((res->get_input_source_output(0)));
375             if (match)
376                 break;
377         }
378
379         // All nodes are in the TI body should be matched in pattern
380         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
381             return false;
382
383         auto pattern_map = matcher.get_pattern_map();
384
385         auto params = func->get_parameters();
386         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
387         int64_t stride = 0, slice_axis = 0;
388         size_t batch_size = 0;
389         for (const auto& input_desc : ti->get_input_descriptions()) {
390             auto param = params[input_desc->m_body_parameter_index];
391             if (param == pattern_map[data]) {
392                 // to get batch size value
393                 if (param->get_partial_shape().is_dynamic()) {
394                     return false;
395                 }
396                 auto slice_input
397                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
398                 if (!slice_input)
399                     return false;
400
401                 stride = slice_input->m_stride;
402                 slice_axis = slice_input->m_axis;
403                 if (!(slice_axis == 0 || slice_axis == 1)) {
404                     return false;
405                 }
406                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
407                 ordered_in_descs[0] = input_desc;
408             } else if (param == pattern_map[input_H_state]) {
409                 ordered_in_descs[1] = input_desc;
410             } else {
411                 return false;
412             }
413         }
414
415         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
416
417         auto results = func->get_results();
418         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(2);
419         for (const auto& output_desc : ti->get_output_descriptions()) {
420             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
421             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
422                 auto concat_output
423                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
424                 if (!concat_output)
425                     return false;
426
427                 stride = concat_output->m_stride;
428                 ordered_out_descs[0] = output_desc;
429             } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
430                 ordered_out_descs[1] = output_desc;
431             } else {
432                 return false;
433             }
434         }
435
436         const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::GRUCell>(pattern_map[cell]);
437
438         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
439         if (slice_axis == 0) {
440             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
441             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
442         }
443
444         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
445         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
446
447         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
448         auto in_3 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
449         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
450         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
451         auto sequence = std::make_shared<opset5::GRUSequence>(
452                 in_0,
453                 in_1,
454                 seq_lengths,
455                 in_3,
456                 in_4,
457                 in_5,
458                 rnn_cell->get_hidden_size(),
459                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
460                 rnn_cell->get_activations(),
461                 rnn_cell->get_activations_alpha(),
462                 rnn_cell->get_activations_beta(),
463                 rnn_cell->get_clip(),
464                 rnn_cell->get_linear_before_reset());
465
466         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
467         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
468
469         Output<Node> out = sequence->output(0);
470         if (slice_axis == 0) {
471             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
472             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
473         }
474         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
475
476         ngraph::NodeVector outputs = {out_0, out_1};
477         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
478             if (ordered_out_descs[i]) {
479                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
480                     input.replace_source_output(outputs[i]->output(0));
481                 }
482                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
483             }
484         }
485
486         ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
487         if (slice_axis == 0) {
488             new_nodes.push_back(out);
489             new_nodes.push_back(in_0);
490         }
491         copy_runtime_info(ti, as_node_vector(new_nodes));
492         return true;
493     };
494
495     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToGRUSequence");
496     register_matcher(m, callback);
497 }