Fixed static analysis issues (transformations) (#3276)
[platform/upstream/dldt.git] / inference-engine / src / legacy_api / src / transformations / convert_opset1_to_legacy / convert_sequences_to_sequences_ie.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "legacy/transformations/convert_opset1_to_legacy/convert_sequences_to_sequences_ie.hpp"
6
7 #include <memory>
8
9 #include <ngraph/opsets/opset5.hpp>
10 #include <ngraph/rt_info.hpp>
11 #include <ngraph/pattern/op/wrap_type.hpp>
12
13 #include <legacy/ngraph_ops/lstm_sequence_ie.hpp>
14 #include <legacy/ngraph_ops/gru_sequence_ie.hpp>
15 #include <legacy/ngraph_ops/rnn_sequence_ie.hpp>
16
// Runtime type information for the three sequence-conversion matcher passes defined in this
// file: each pass class is registered under the given type-name string with version 0.
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertLSTMSequenceMatcher, "ConvertLSTMSequenceMatcher", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGRUSequenceMatcher, "ConvertGRUSequenceMatcher", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertRNNSequenceMatcher, "ConvertRNNSequenceMatcher", 0);
20
21 namespace {
22     int64_t get_seq_axis(const std::shared_ptr<ngraph::Node>& sequence_node) {
23         // Optimization.
24         // Plug-ins support seq_axis attribute (value 1 or 0) for Seq ops, but according to the spec we don't
25         // support this attribute and should insert Transpose layer before and after Seq op in TI to Sequences
26         // transformation. Additional Transpose layers affect the performance, so we try to detect pattern
27         // Transpose(axis_order={1,0,2}) -> Seq -> Transpose(axis_order={2,1,0,3}
28         // and replace unnecessary Transpose ops with SeqIE (seq_axis = 0) to transfer value
29         // of the attribute to plug-ins.
30         // todo: specify seq_axis attribute for Sequence ops.
31         int64_t seq_axis = 1; // default
32         const auto& target_inputs = sequence_node->output(0).get_target_inputs();
33         if (target_inputs.size() == 1) {
34             const auto& transpose_before = std::dynamic_pointer_cast<ngraph::opset5::Transpose>(sequence_node->input_value(0).get_node_shared_ptr());
35             const auto& transpose_after = std::dynamic_pointer_cast<ngraph::opset5::Transpose>(target_inputs.begin()->get_node()->shared_from_this());
36             if (transpose_after != nullptr && transpose_before != nullptr) {
37                 auto order_before = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
38                         transpose_before->input_value(1).get_node_shared_ptr());
39                 auto order_after = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
40                         transpose_after->input_value(1).get_node_shared_ptr());
41                 if (order_before != nullptr && order_after != nullptr) {
42                     auto order_before_values = order_before->cast_vector<int64_t>();
43                     auto order_after_values = order_after->cast_vector<int64_t>();
44                     std::vector<int64_t> order_ref_before = {1, 0, 2};
45                     std::vector<int64_t> order_ref_after = {2, 1, 0, 3};
46                     if (order_before_values == order_ref_before && order_after_values == order_ref_after) {
47                         seq_axis = 0;
48                     }
49                 }
50             }
51         }
52         return seq_axis;
53     }
54 } // namespace
55
56 ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
57     auto lstm_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::LSTMSequence>();
58
59     ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
60         auto lstm_sequence = std::dynamic_pointer_cast<ngraph::opset5::LSTMSequence>(m.get_match_root());
61         if (!lstm_sequence) {
62             return false;
63         }
64
65         const auto& W = lstm_sequence->input_value(4);
66         const auto& R = lstm_sequence->input_value(5);
67
68         // Bidirectional cases are not supported
69         if (lstm_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
70             return false;
71
72         // Detect pattern: Transpose_before -> Seq -> Transpose_after
73         auto seq_axis = get_seq_axis(lstm_sequence);
74         ngraph::Output<ngraph::Node> in_0 = lstm_sequence->input(0).get_source_output();
75         if (seq_axis == 0) {
76             // input(0) to Transpose_before
77             in_0 = lstm_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
78         }
79         // for forward/reverse cases we can squeeze num_direction dimension
80         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
81         auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(1), axis_1);
82         auto in_2 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(2), axis_1);
83         auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
84         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
85         auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
86         auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(6), axis_2);
87         auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
88                 in_0,  // X
89                 in_1,  // initial_hidden_state
90                 in_2,  // initial_cell_state
91                 lstm_sequence->input_value(3),
92                 in_3,  // WR
93                 in_4,  // B
94                 lstm_sequence->get_hidden_size(),
95                 lstm_sequence->get_direction(),
96                 lstm_sequence->get_activations(),
97                 lstm_sequence->get_activations_alpha(),
98                 lstm_sequence->get_activations_beta(),
99                 lstm_sequence->get_clip(),
100                 seq_axis);
101
102         auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
103         auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(0), unsqueeze_axis);
104         auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(1), unsqueeze_axis);
105         auto unsqueeze_3 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(2), unsqueeze_axis);
106
107         ngraph::copy_runtime_info(lstm_sequence, {concat, lstm_sequence_ie, in_1, in_2, in_3, in_4, unsqueeze_1,
108                                                   unsqueeze_2, unsqueeze_3});
109         unsqueeze_1->set_friendly_name(lstm_sequence->get_friendly_name()+".0");
110         unsqueeze_2->set_friendly_name(lstm_sequence->get_friendly_name()+".1");
111         unsqueeze_3->set_friendly_name(lstm_sequence->get_friendly_name()+".2");
112         if (seq_axis == 1) {
113             ngraph::replace_node(lstm_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
114         } else {
115             const auto &lstm_target_inputs = lstm_sequence->output(0).get_target_inputs();
116             if (lstm_target_inputs.empty())
117                 return false;
118             auto transpose_after = lstm_target_inputs.begin()->get_node()->shared_from_this();
119             ngraph::replace_node(transpose_after, unsqueeze_1);
120             ngraph::replace_node(lstm_sequence, {lstm_sequence_ie->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
121         }
122         return true;
123     };
124
125     auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_sequence_ngraph, "ConvertLSTMSequenceToLSTMSequenceIE");
126     this->register_matcher(m, callback);
127 }
128
129 ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
130     auto gru_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::GRUSequence>();
131
132     ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
133         auto gru_sequence = std::dynamic_pointer_cast<ngraph::opset5::GRUSequence>(m.get_match_root());
134         if (!gru_sequence) {
135             return false;
136         }
137
138         auto W = gru_sequence->input_value(3);
139         auto R = gru_sequence->input_value(4);
140
141         // Bidirectional cases are not supported
142         if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
143             return false;
144
145         // Detect pattern: Transpose_before -> Seq -> Transpose_after
146         auto seq_axis = get_seq_axis(gru_sequence);
147         ngraph::Output<ngraph::Node> in_0 = gru_sequence->input(0).get_source_output();
148         if (seq_axis == 0) {
149             // input(0) to Transpose_before
150             in_0 = gru_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
151         }
152         // for forward/reverse cases we can squeeze num_direction dimension
153         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
154         auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->input_value(1), axis_1);
155         auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
156         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
157         auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
158         auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->input_value(5), axis_2);
159
160         auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
161                 in_0, // X
162                 in_1,  // initial_hidden_state
163                 gru_sequence->input_value(2),
164                 in_3,  // WR
165                 in_4,  // B
166                 gru_sequence->get_hidden_size(),
167                 gru_sequence->get_direction(),
168                 gru_sequence->get_activations(),
169                 gru_sequence->get_activations_alpha(),
170                 gru_sequence->get_activations_beta(),
171                 gru_sequence->get_clip(),
172                 gru_sequence->get_linear_before_reset(),
173                 seq_axis);
174
175         auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
176         auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(gru_sequence_ie->output(0), unsqueeze_axis);
177         auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(gru_sequence_ie->output(1), unsqueeze_axis);
178
179         ngraph::copy_runtime_info(gru_sequence, {concat, gru_sequence_ie, unsqueeze_1, unsqueeze_2, in_1, in_3, in_4});
180         unsqueeze_1->set_friendly_name(gru_sequence->get_friendly_name()+".0");
181         unsqueeze_2->set_friendly_name(gru_sequence->get_friendly_name()+".1");
182         if (seq_axis == 1) {
183             ngraph::replace_node(gru_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
184         } else {
185             const auto &gru_target_inputs = gru_sequence->output(0).get_target_inputs();
186             if (gru_target_inputs.empty())
187                 return false;
188             auto transpose_after = gru_target_inputs.begin()->get_node()->shared_from_this();
189             ngraph::replace_node(transpose_after, unsqueeze_1);
190             ngraph::replace_node(gru_sequence, {gru_sequence_ie->output(0), unsqueeze_2->output(0)});
191         }
192         return true;
193     };
194
195     auto m = std::make_shared<ngraph::pattern::Matcher>(gru_sequence_ngraph, "ConvertGRUSequenceToGRUSequenceIE");
196     this->register_matcher(m, callback);
197 }
198
199 ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
200     auto rnn_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::RNNSequence>();
201
202     ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
203         auto rnn_sequence = std::dynamic_pointer_cast<ngraph::opset5::RNNSequence>(m.get_match_root());
204         if (!rnn_sequence) {
205             return false;
206         }
207
208         // Bidirectional cases are not supported
209         if (rnn_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
210             return false;
211
212         // Detect pattern: Transpose_before -> Seq -> Transpose_after
213         auto seq_axis = get_seq_axis(rnn_sequence);
214         ngraph::Output<ngraph::Node> in_0 = rnn_sequence->input(0).get_source_output();
215         if (seq_axis == 0) {
216             // input(0) to Transpose_before
217             in_0 = rnn_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
218         }
219
220         auto W = rnn_sequence->input_value(3);
221         auto R = rnn_sequence->input_value(4);
222
223         // for forward/reverse cases we can squeeze num_direction dimension
224         auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
225         auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->input_value(1), axis_1);
226         auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
227         auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
228         auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
229         auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->input_value(5), axis_2);
230         auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
231                 in_0,  // X
232                 in_1,  // initial_hidden_state
233                 rnn_sequence->input_value(2),
234                 in_3,  // WR
235                 in_4,  // B
236                 rnn_sequence->get_hidden_size(),
237                 rnn_sequence->get_direction(),
238                 rnn_sequence->get_activations(),
239                 rnn_sequence->get_activations_alpha(),
240                 rnn_sequence->get_activations_beta(),
241                 rnn_sequence->get_clip(),
242                 seq_axis);
243
244         auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
245         auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(rnn_sequence_ie->output(0), unsqueeze_axis);
246         auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(rnn_sequence_ie->output(1), unsqueeze_axis);
247
248         ngraph::copy_runtime_info(rnn_sequence, {concat, rnn_sequence_ie, in_1, in_3, in_4, unsqueeze_1,
249                                                  unsqueeze_2});
250         unsqueeze_1->set_friendly_name(rnn_sequence->get_friendly_name()+".0");
251         unsqueeze_2->set_friendly_name(rnn_sequence->get_friendly_name()+".1");
252
253         if (seq_axis == 1) {
254             ngraph::replace_node(rnn_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
255         } else {
256             const auto &rnn_target_inputs = rnn_sequence->output(0).get_target_inputs();
257             if (rnn_target_inputs.empty())
258                 return false;
259             auto transpose_after = rnn_target_inputs.begin()->get_node()->shared_from_this();
260             ngraph::replace_node(transpose_after, unsqueeze_1);
261             ngraph::replace_node(rnn_sequence, {rnn_sequence_ie->output(0), unsqueeze_2->output(0)});
262         }
263         return true;
264     };
265
266     auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_sequence_ngraph, "ConvertRNNSequenceToRNNSequenceIE");
267     this->register_matcher(m, callback);
268 }