Fixed static analysis issues (transformations) (#3276)
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / op_conversions / convert_ti_to_sequences.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "transformations/op_conversions/convert_ti_to_sequences.hpp"
6 #include "transformations/utils/utils.hpp"
7
8 #include <memory>
9 #include <vector>
10
11 #include <ngraph/node.hpp>
12 #include <ngraph/pass/manager.hpp>
13 #include <ngraph/opsets/opset5.hpp>
14 #include <ngraph/opsets/opset1.hpp>
15 #include <ngraph/rt_info.hpp>
16 #include <ngraph/graph_util.hpp>
17 #include <ngraph/specialize_function.hpp>
18 #include <ngraph/pattern/op/wrap_type.hpp>
19
20 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToLSTMSequence, "ConvertTensorIteratorToLSTMSequence", 0);
21 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToRNNSequence, "ConvertTensorIteratorToRNNSequence", 0);
22 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "ConvertTensorIteratorToGRUSequence", 0);
23
24 ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
25     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
26                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
27     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
28         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
29         if (!ti || m_transformation_callback(ti))
30             return false;
31
32         // create pattern
33         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
34         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
35         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
36         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
37         auto input_C_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
38         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
39         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4, 1});
40         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{4});
41
42         auto cell = std::make_shared<ngraph::opset5::LSTMCell>(squeeze, input_H_state, input_C_state,
43                                                                input_W, input_R, input_B, 1);
44         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
45         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
46         ngraph::pattern::Matcher matcher(unsqueeze);
47
48         bool match = false;
49         auto func = ti->get_body();
50         for (const auto& res : func->get_results()) {
51             match = matcher.match((res->get_input_source_output(0)));
52             if (match)
53                 break;
54         }
55
56         // support for opset1::LSTMCell
57         auto cell_v1 = std::make_shared<ngraph::opset1::LSTMCell>(squeeze, input_H_state, input_C_state,
58                                                                  input_W, input_R, input_B, 1);
59         if (!match) {
60             unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell_v1, pattern_2, false);
61             matcher.clear_state();
62             matcher.m_pattern_node = unsqueeze;
63             for (const auto& res : func->get_results()) {
64                 match = matcher.match((res->get_input_source_output(0)));
65                 if (match)
66                     break;
67             }
68         }
69
70         // All nodes are in the TI body should be matched in pattern
71         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
72             return false;
73
74         auto pattern_map = matcher.get_pattern_map();
75         std::shared_ptr<Node>& found_cell = pattern_map[cell];
76         if (!found_cell)
77             found_cell = pattern_map[cell_v1];
78
79         auto params = func->get_parameters();
80         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
81         int64_t stride = 0, slice_axis = 0;
82         size_t batch_size = 0;
83         for (const auto& input_desc : ti->get_input_descriptions()) {
84             auto param = params[input_desc->m_body_parameter_index];
85             if (param == pattern_map[data]) {
86                 // to get batch size value
87                 if (param->get_partial_shape().is_dynamic()) {
88                     return false;
89                 }
90                 auto slice_input
91                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
92                 if (!slice_input)
93                     return false;
94
95                 stride = slice_input->m_stride;
96                 slice_axis = slice_input->m_axis;
97
98                 if (!(slice_axis == 0 || slice_axis == 1)) {
99                     return false;
100                 }
101                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
102                 ordered_in_descs[0] = input_desc;
103             } else if (param == pattern_map[input_H_state]) {
104                 ordered_in_descs[1] = input_desc;
105             } else if (param == pattern_map[input_C_state]) {
106                 ordered_in_descs[2] = input_desc;
107             } else {
108                 return false;
109             }
110         }
111
112         auto results = func->get_results();
113         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(3);
114         for (const auto& output_desc : ti->get_output_descriptions()) {
115             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
116             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
117                 auto concat_output
118                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
119                 if (!concat_output)
120                     return false;
121
122                 stride = concat_output->m_stride;
123                 ordered_out_descs[0] = output_desc;
124             } else if (res->get_input_source_output(0) == found_cell->output(0)) {
125                 ordered_out_descs[1] = output_desc;
126             } else if (res->get_input_source_output(0) == found_cell->output(1)) {
127                 ordered_out_descs[2] = output_desc;
128             } else {
129                 return false;
130             }
131         }
132
133         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
134         const auto& lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(found_cell);
135         if (lstm_cell == nullptr)
136             return false;
137         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
138         if (slice_axis == 0) {
139             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
140             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
141         }
142         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
143         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
144         auto in_2 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[2]->m_input_index], axis_1);
145
146         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
147         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
148         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
149         auto in_6 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
150         auto sequence = std::make_shared<opset5::LSTMSequence>(
151                 in_0,
152                 in_1,
153                 in_2,
154                 seq_lengths,
155                 in_4,
156                 in_5,
157                 in_6,
158                 lstm_cell->get_hidden_size(),
159                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
160                 lstm_cell->get_activations_alpha(),
161                 lstm_cell->get_activations_beta(),
162                 lstm_cell->get_activations(),
163                 lstm_cell->get_clip());
164
165         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
166         Output<Node> out = sequence->output(0);
167         if (slice_axis == 0) {
168             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
169             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
170         }
171         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
172         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
173         auto out_2 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(2), axis_out);
174
175         ngraph::NodeVector outputs = {out_0, out_1, out_2};
176         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
177             if (ordered_out_descs[i]) {
178                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
179                     input.replace_source_output(outputs[i]->output(0));
180                 }
181                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
182             }
183         }
184
185         ngraph::OutputVector new_nodes = {in_1, in_2, in_4, in_5, in_6, sequence->output(0), out_0, out_1, out_2};
186         if (slice_axis == 0) {
187             new_nodes.push_back(out);
188             new_nodes.push_back(in_0.get_node_shared_ptr());
189         }
190         copy_runtime_info(ti, as_node_vector(new_nodes));
191         return true;
192     };
193
194     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToLSTMSequence");
195     register_matcher(m, callback);
196 }
197
198 ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
199     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
200                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
201     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
202         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
203         if (!ti || m_transformation_callback(ti))
204             return false;
205
206         // create pattern
207         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
208         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
209         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
210
211         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
212         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
213         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1, 1});
214         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{1});
215
216         auto cell = std::make_shared<ngraph::opset5::RNNCell>(squeeze, input_H_state, input_W, input_R, input_B, 1);
217
218         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
219         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
220         ngraph::pattern::Matcher matcher(unsqueeze);
221
222         bool match = false;
223         auto func = ti->get_body();
224         for (const auto& res : func->get_results()) {
225             match = matcher.match((res->get_input_source_output(0)));
226             if (match)
227                 break;
228         }
229
230         // All nodes are in the TI body should be matched in pattern
231         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
232             return false;
233
234         auto pattern_map = matcher.get_pattern_map();
235
236         auto params = func->get_parameters();
237         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
238         int64_t stride = 0, slice_axis = 0;
239         size_t batch_size = 0;
240         for (const auto& input_desc : ti->get_input_descriptions()) {
241             auto param = params[input_desc->m_body_parameter_index];
242             if (param == pattern_map[data]) {
243                 // to get batch size value
244                 if (param->get_partial_shape().is_dynamic()) {
245                     return false;
246                 }
247                 auto slice_input
248                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
249                 if (!slice_input)
250                     return false;
251
252                 stride = slice_input->m_stride;
253                 slice_axis = slice_input->m_axis;
254                 if (!(slice_axis == 0 || slice_axis == 1)) {
255                     return false;
256                 }
257                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
258                 ordered_in_descs[0] = input_desc;
259             } else if (param == pattern_map[input_H_state]) {
260                 ordered_in_descs[1] = input_desc;
261             } else {
262                 return false;
263             }
264         }
265
266         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
267
268         auto results = func->get_results();
269         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(2);
270         for (const auto& output_desc : ti->get_output_descriptions()) {
271             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
272             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
273                 auto concat_output
274                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
275                 if (!concat_output)
276                     return false;
277
278                 stride = concat_output->m_stride;
279                 ordered_out_descs[0] = output_desc;
280             } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
281                 ordered_out_descs[1] = output_desc;
282             } else {
283                 return false;
284             }
285         }
286
287         const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::RNNCell>(pattern_map[cell]);
288         if (rnn_cell == nullptr)
289             return false;
290
291         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
292         if (slice_axis == 0) {
293             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
294             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
295         }
296
297         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
298         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
299
300         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
301         auto in_3 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
302         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
303         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
304         auto sequence = std::make_shared<opset5::RNNSequence>(
305                 in_0,
306                 in_1,
307                 seq_lengths,
308                 in_3,
309                 in_4,
310                 in_5,
311                 rnn_cell->get_hidden_size(),
312                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
313                 rnn_cell->get_activations(),
314                 rnn_cell->get_activations_alpha(),
315                 rnn_cell->get_activations_beta(),
316                 rnn_cell->get_clip());
317
318         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
319         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
320
321         Output<Node> out = sequence->output(0);
322         if (slice_axis == 0) {
323             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
324             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
325         }
326         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
327
328         ngraph::NodeVector outputs = {out_0, out_1};
329         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
330             if (ordered_out_descs[i]) {
331                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
332                     input.replace_source_output(outputs[i]->output(0));
333                 }
334                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
335             }
336         }
337
338         ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
339         if (slice_axis == 0) {
340             new_nodes.push_back(out);
341             new_nodes.push_back(in_0);
342         }
343         copy_runtime_info(ti, as_node_vector(new_nodes));
344         return true;
345     };
346
347     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToRNNSequence");
348     register_matcher(m, callback);
349 }
350
351 ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
352     auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
353                                                                         ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
354     ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
355         auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
356         if (!ti || m_transformation_callback(ti))
357             return false;
358
359         // create pattern
360         auto data = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 1});
361         auto pattern_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 1});
362         auto squeeze = std::make_shared<ngraph::opset5::Reshape>(data, pattern_1, false);
363
364         auto input_H_state = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1});
365         auto input_W = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
366         auto input_R = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3, 1});
367         auto input_B = std::make_shared<ngraph::opset5::Constant>(ngraph::element::f32, ngraph::Shape{3});
368
369         auto cell = std::make_shared<ngraph::opset5::GRUCell>(squeeze, input_H_state, input_W, input_R, input_B, 1);
370
371         auto pattern_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 1, 1});
372         auto unsqueeze = std::make_shared<ngraph::opset5::Reshape>(cell, pattern_2, false);
373         ngraph::pattern::Matcher matcher(unsqueeze);
374
375         bool match = false;
376         auto func = ti->get_body();
377         for (const auto& res : func->get_results()) {
378             match = matcher.match((res->get_input_source_output(0)));
379             if (match)
380                 break;
381         }
382
383         // All nodes are in the TI body should be matched in pattern
384         if (!match || (matcher.get_matched_nodes().size() + func->get_results().size()) != func->get_ops().size())
385             return false;
386
387         auto pattern_map = matcher.get_pattern_map();
388
389         auto params = func->get_parameters();
390         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::InputDescription>> ordered_in_descs(3);
391         int64_t stride = 0, slice_axis = 0;
392         size_t batch_size = 0;
393         for (const auto& input_desc : ti->get_input_descriptions()) {
394             auto param = params[input_desc->m_body_parameter_index];
395             if (param == pattern_map[data]) {
396                 // to get batch size value
397                 if (param->get_partial_shape().is_dynamic()) {
398                     return false;
399                 }
400                 auto slice_input
401                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::SliceInputDescription>(input_desc);
402                 if (!slice_input)
403                     return false;
404
405                 stride = slice_input->m_stride;
406                 slice_axis = slice_input->m_axis;
407                 if (!(slice_axis == 0 || slice_axis == 1)) {
408                     return false;
409                 }
410                 batch_size = param->get_shape()[slice_axis == 0 ? 1 : 0];
411                 ordered_in_descs[0] = input_desc;
412             } else if (param == pattern_map[input_H_state]) {
413                 ordered_in_descs[1] = input_desc;
414             } else {
415                 return false;
416             }
417         }
418
419         auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{batch_size}, {ti->get_num_iterations()});
420
421         auto results = func->get_results();
422         std::vector<std::shared_ptr<ngraph::opset5::TensorIterator::OutputDescription>> ordered_out_descs(2);
423         for (const auto& output_desc : ti->get_output_descriptions()) {
424             std::shared_ptr<opset5::Result> res = results[output_desc->m_body_value_index];
425             if (res->get_input_source_output(0) == pattern_map[unsqueeze]) {
426                 auto concat_output
427                         = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator::ConcatOutputDescription>(output_desc);
428                 if (!concat_output)
429                     return false;
430
431                 stride = concat_output->m_stride;
432                 ordered_out_descs[0] = output_desc;
433             } else if (res->get_input_source_output(0) == pattern_map[cell]->output(0)) {
434                 ordered_out_descs[1] = output_desc;
435             } else {
436                 return false;
437             }
438         }
439
440         const auto& rnn_cell = std::dynamic_pointer_cast<ngraph::opset5::GRUCell>(pattern_map[cell]);
441         if (rnn_cell == nullptr)
442             return false;
443
444         auto in_0 = ti->input_values()[ordered_in_descs[0]->m_input_index];
445         if (slice_axis == 0) {
446             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {1, 0, 2});
447             in_0 = std::make_shared<ngraph::opset5::Transpose>(ti->input_values()[ordered_in_descs[0]->m_input_index], order);
448         }
449
450         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
451         auto in_1 = std::make_shared<ngraph::opset5::Unsqueeze>(ti->input_values()[ordered_in_descs[1]->m_input_index], axis_1);
452
453         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
454         auto in_3 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_W]->output(0).get_node_shared_ptr(), axis_2);
455         auto in_4 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_R]->output(0).get_node_shared_ptr(), axis_2);
456         auto in_5 = std::make_shared<ngraph::opset5::Unsqueeze>(pattern_map[input_B]->output(0).get_node_shared_ptr(), axis_2);
457         auto sequence = std::make_shared<opset5::GRUSequence>(
458                 in_0,
459                 in_1,
460                 seq_lengths,
461                 in_3,
462                 in_4,
463                 in_5,
464                 rnn_cell->get_hidden_size(),
465                 stride > 0 ? ngraph::op::RecurrentSequenceDirection::FORWARD: ngraph::op::RecurrentSequenceDirection::REVERSE,
466                 rnn_cell->get_activations(),
467                 rnn_cell->get_activations_alpha(),
468                 rnn_cell->get_activations_beta(),
469                 rnn_cell->get_clip(),
470                 rnn_cell->get_linear_before_reset());
471
472         auto axis_out = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
473         auto out_1 = std::make_shared<ngraph::opset5::Squeeze>(sequence->output(1), axis_out);
474
475         Output<Node> out = sequence->output(0);
476         if (slice_axis == 0) {
477             auto order = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {2, 1, 0, 3});
478             out = std::make_shared<ngraph::opset5::Transpose>(out, order);
479         }
480         auto out_0 = std::make_shared<ngraph::opset5::Squeeze>(out, axis_out);
481
482         ngraph::NodeVector outputs = {out_0, out_1};
483         for (size_t i = 0; i < ordered_out_descs.size(); ++i) {
484             if (ordered_out_descs[i]) {
485                 for (const auto &input : ti->output(ordered_out_descs[i]->m_output_index).get_target_inputs()) {
486                     input.replace_source_output(outputs[i]->output(0));
487                 }
488                 outputs[i]->get_output_tensor(0).set_name(op::util::create_ie_output_name(ti->output(ordered_out_descs[i]->m_output_index)));
489             }
490         }
491
492         ngraph::OutputVector new_nodes = {in_1, in_3, in_4, in_5, sequence, out_0, out_1};
493         if (slice_axis == 0) {
494             new_nodes.push_back(out);
495             new_nodes.push_back(in_0);
496         }
497         copy_runtime_info(ti, as_node_vector(new_nodes));
498         return true;
499     };
500
501     auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "ConvertTensorIteratorToGRUSequence");
502     register_matcher(m, callback);
503 }