1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "legacy/transformations/convert_opset1_to_legacy/convert_sequences_to_sequences_ie.hpp"
9 #include <ngraph/opsets/opset5.hpp>
10 #include <ngraph/rt_info.hpp>
11 #include <ngraph/pattern/op/wrap_type.hpp>
13 #include <legacy/ngraph_ops/lstm_sequence_ie.hpp>
14 #include <legacy/ngraph_ops/gru_sequence_ie.hpp>
15 #include <legacy/ngraph_ops/rnn_sequence_ie.hpp>
// RTTI definitions for the three matcher passes defined below; each is registered under its own
// class name as the type-name string (last argument 0 — presumably the type-info version; confirm
// against the NGRAPH_RTTI_DEFINITION macro).
17 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertLSTMSequenceMatcher, "ConvertLSTMSequenceMatcher", 0);
18 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGRUSequenceMatcher, "ConvertGRUSequenceMatcher", 0);
19 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertRNNSequenceMatcher, "ConvertRNNSequenceMatcher", 0);
// Computes the seq_axis value to hand to the legacy *SequenceIE op built for `sequence_node`.
// The value starts at 1 (the default) and the code below looks for the pattern
//   Transpose(order = Constant{1,0,2}) -> sequence_node -> Transpose(order = Constant{2,1,0,3})
// where the second Transpose is the only consumer of sequence_node's output 0.
// NOTE(review): the tail of this function (the body of the final `if`, the closing braces and the
// return statement) is not visible in this excerpt; per the comment below, a match is meant to
// select seq_axis = 0.
22 int64_t get_seq_axis(const std::shared_ptr<ngraph::Node>& sequence_node) {
24 // Plug-ins support the seq_axis attribute (value 1 or 0) for Seq ops, but according to the spec we don't
25 // support this attribute and have to insert Transpose layers before and after the Seq op in the TI-to-Sequences
26 // transformation. Additional Transpose layers hurt performance, so we try to detect the pattern
27 // Transpose(axis_order={1,0,2}) -> Seq -> Transpose(axis_order={2,1,0,3})
28 // and replace the unnecessary Transpose ops with SeqIE (seq_axis = 0) to pass the value
29 // of the attribute on to plug-ins.
30 // TODO: specify seq_axis attribute for Sequence ops.
31 int64_t seq_axis = 1; // default
// Only output 0 is inspected, and the pattern requires it to have exactly one consumer.
32 const auto& target_inputs = sequence_node->output(0).get_target_inputs();
33 if (target_inputs.size() == 1) {
// The producer of input 0 and the single consumer of output 0 must both be opset5::Transpose;
// dynamic_pointer_cast yields nullptr otherwise.
34 const auto& transpose_before = std::dynamic_pointer_cast<ngraph::opset5::Transpose>(sequence_node->input_value(0).get_node_shared_ptr());
35 const auto& transpose_after = std::dynamic_pointer_cast<ngraph::opset5::Transpose>(target_inputs.begin()->get_node()->shared_from_this());
36 if (transpose_after != nullptr && transpose_before != nullptr) {
// Both permutation orders (input 1 of each Transpose) must be Constants so their values can be compared.
37 auto order_before = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
38 transpose_before->input_value(1).get_node_shared_ptr());
39 auto order_after = std::dynamic_pointer_cast<ngraph::opset5::Constant>(
40 transpose_after->input_value(1).get_node_shared_ptr());
41 if (order_before != nullptr && order_after != nullptr) {
42 auto order_before_values = order_before->cast_vector<int64_t>();
43 auto order_after_values = order_after->cast_vector<int64_t>();
44 std::vector<int64_t> order_ref_before = {1, 0, 2};
45 std::vector<int64_t> order_ref_after = {2, 1, 0, 3};
// Exact element-wise match of both orders against the reference permutations.
46 if (order_before_values == order_ref_before && order_after_values == order_ref_after) {
// Matcher pass that rewrites an opset5::LSTMSequence into the legacy op::LSTMSequenceIE:
// the num_directions axis of the inputs is squeezed, W and R are concatenated, and each IE
// output is unsqueezed back at axis 1 before the original node is replaced.
// NOTE(review): this excerpt omits several source lines (the embedded numbering skips e.g.
// 60 -> 65, 70 -> 72, 74 -> 76, 87 -> 89, 91 -> 94, 99 -> 102, 111 -> 113 -> 115, 116 -> 118 and
// 120 -> 125); statements described below as "not visible" live in those gaps.
56 ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
// Match any opset5::LSTMSequence node.
57 auto lstm_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::LSTMSequence>();
59 ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
60 auto lstm_sequence = std::dynamic_pointer_cast<ngraph::opset5::LSTMSequence>(m.get_match_root());
// W (input 4) and R (input 5); presumably following the opset5 LSTMSequence input order
// X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B — confirm against the spec.
65 const auto& W = lstm_sequence->input_value(4);
66 const auto& R = lstm_sequence->input_value(5);
68 // Bidirectional cases are not supported
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
69 if (lstm_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
72 // Detect pattern: Transpose_before -> Seq -> Transpose_after
73 auto seq_axis = get_seq_axis(lstm_sequence);
// By default the IE op consumes the same data input as the original node.
74 ngraph::Output<ngraph::Node> in_0 = lstm_sequence->input(0).get_source_output();
76 // Redirect in_0 to the input of Transpose_before, bypassing that Transpose.
// NOTE(review): in the visible text this reassignment is unconditional; the gap 74 -> 76 suggests
// a guard (presumably on seq_axis == 0) that is not shown. Without such a guard any producer of
// input 0 would be skipped, so confirm the guard exists.
77 in_0 = lstm_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
79 // for forward/reverse cases we can squeeze num_direction dimension
80 auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
81 auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(1), axis_1);
82 auto in_2 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(2), axis_1);
// Concatenate W and R along axis 2, then squeeze axis 0 from the concatenation and from input 6
// (presumably B — confirm).
83 auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
84 auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
85 auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
86 auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(lstm_sequence->input_value(6), axis_2);
// Build the legacy op from the squeezed inputs, input 3 and the original node's attributes
// (hidden size, direction, activations, alpha/beta, clip). Some arguments are not visible here.
87 auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
89 in_1, // initial_hidden_state
90 in_2, // initial_cell_state
91 lstm_sequence->input_value(3),
94 lstm_sequence->get_hidden_size(),
95 lstm_sequence->get_direction(),
96 lstm_sequence->get_activations(),
97 lstm_sequence->get_activations_alpha(),
98 lstm_sequence->get_activations_beta(),
99 lstm_sequence->get_clip(),
// Re-insert the num_directions axis (axis 1) on each of the three IE outputs.
102 auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
103 auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(0), unsqueeze_axis);
104 auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(1), unsqueeze_axis);
105 auto unsqueeze_3 = std::make_shared<ngraph::opset5::Unsqueeze>(lstm_sequence_ie->output(2), unsqueeze_axis);
// Propagate runtime info from the original node to every node created above (the axis
// Constants are not included).
107 ngraph::copy_runtime_info(lstm_sequence, {concat, lstm_sequence_ie, in_1, in_2, in_3, in_4, unsqueeze_1,
108 unsqueeze_2, unsqueeze_3});
// Name the Unsqueeze nodes "<original name>.0", ".1", ".2".
109 unsqueeze_1->set_friendly_name(lstm_sequence->get_friendly_name()+".0");
110 unsqueeze_2->set_friendly_name(lstm_sequence->get_friendly_name()+".1");
111 unsqueeze_3->set_friendly_name(lstm_sequence->get_friendly_name()+".2");
// Rewire every consumer of the original node's outputs to the corresponding Unsqueeze.
// NOTE(review): this call and the statements after it both replace lstm_sequence; the numbering
// gaps (111 -> 113 -> 115) suggest they are alternative branches (presumably selected by seq_axis)
// whose condition is not visible here — confirm.
113 ngraph::replace_node(lstm_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
// Alternative rewiring: the first consumer of output 0 (Transpose_after in the detected pattern)
// is replaced by unsqueeze_1, and the original node's outputs are redirected to the IE output 0
// and the remaining Unsqueezes.
// NOTE(review): nothing visible here copies Transpose_after's friendly name or runtime info onto
// unsqueeze_1 — confirm that downstream output naming does not depend on it.
115 const auto &lstm_target_inputs = lstm_sequence->output(0).get_target_inputs();
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
116 if (lstm_target_inputs.empty())
118 auto transpose_after = lstm_target_inputs.begin()->get_node()->shared_from_this();
119 ngraph::replace_node(transpose_after, unsqueeze_1);
120 ngraph::replace_node(lstm_sequence, {lstm_sequence_ie->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
// Register the callback for the LSTMSequence pattern under the given matcher name.
125 auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_sequence_ngraph, "ConvertLSTMSequenceToLSTMSequenceIE");
126 this->register_matcher(m, callback);
// Matcher pass that rewrites an opset5::GRUSequence into the legacy op::GRUSequenceIE:
// the num_directions axis of the inputs is squeezed, W and R are concatenated, and each IE
// output is unsqueezed back at axis 1 before the original node is replaced.
// NOTE(review): this excerpt omits several source lines (the embedded numbering skips e.g.
// 133 -> 138, 142 -> 145, 147 -> 149, 160 -> 162, 163 -> 166, 172 -> 175, 181 -> 183 -> 185,
// 186 -> 188 and 190 -> 195); statements described below as "not visible" live in those gaps.
129 ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
// Match any opset5::GRUSequence node.
130 auto gru_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::GRUSequence>();
132 ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
133 auto gru_sequence = std::dynamic_pointer_cast<ngraph::opset5::GRUSequence>(m.get_match_root());
// W (input 3) and R (input 4); presumably following the opset5 GRUSequence input order
// X, initial_hidden_state, sequence_lengths, W, R, B — confirm against the spec.
138 auto W = gru_sequence->input_value(3);
139 auto R = gru_sequence->input_value(4);
141 // Bidirectional cases are not supported
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
142 if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
145 // Detect pattern: Transpose_before -> Seq -> Transpose_after
146 auto seq_axis = get_seq_axis(gru_sequence);
// By default the IE op consumes the same data input as the original node.
147 ngraph::Output<ngraph::Node> in_0 = gru_sequence->input(0).get_source_output();
149 // Redirect in_0 to the input of Transpose_before, bypassing that Transpose.
// NOTE(review): in the visible text this reassignment is unconditional; the gap 147 -> 149
// suggests a guard (presumably on seq_axis == 0) that is not shown. Without such a guard any
// producer of input 0 would be skipped, so confirm the guard exists.
150 in_0 = gru_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
152 // for forward/reverse cases we can squeeze num_direction dimension
153 auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
154 auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->input_value(1), axis_1);
// Concatenate W and R along axis 2, then squeeze axis 0 from the concatenation and from input 5
// (presumably B — confirm).
155 auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
156 auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
157 auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
158 auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(gru_sequence->input_value(5), axis_2);
// Build the legacy op from the squeezed inputs, input 2 and the original node's attributes
// (hidden size, direction, activations, alpha/beta, clip, linear_before_reset). Some arguments
// are not visible here.
160 auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
162 in_1, // initial_hidden_state
163 gru_sequence->input_value(2),
166 gru_sequence->get_hidden_size(),
167 gru_sequence->get_direction(),
168 gru_sequence->get_activations(),
169 gru_sequence->get_activations_alpha(),
170 gru_sequence->get_activations_beta(),
171 gru_sequence->get_clip(),
172 gru_sequence->get_linear_before_reset(),
// Re-insert the num_directions axis (axis 1) on both IE outputs.
175 auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
176 auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(gru_sequence_ie->output(0), unsqueeze_axis);
177 auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(gru_sequence_ie->output(1), unsqueeze_axis);
// Propagate runtime info from the original node to every node created above (the axis
// Constants are not included).
179 ngraph::copy_runtime_info(gru_sequence, {concat, gru_sequence_ie, unsqueeze_1, unsqueeze_2, in_1, in_3, in_4});
// Name the Unsqueeze nodes "<original name>.0" and ".1".
180 unsqueeze_1->set_friendly_name(gru_sequence->get_friendly_name()+".0");
181 unsqueeze_2->set_friendly_name(gru_sequence->get_friendly_name()+".1");
// Rewire every consumer of the original node's outputs to the corresponding Unsqueeze.
// NOTE(review): this call and the statements after it both replace gru_sequence; the numbering
// gaps (181 -> 183 -> 185) suggest they are alternative branches (presumably selected by seq_axis)
// whose condition is not visible here — confirm.
183 ngraph::replace_node(gru_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
// Alternative rewiring: the first consumer of output 0 (Transpose_after in the detected pattern)
// is replaced by unsqueeze_1, and the original node's outputs are redirected to the IE output 0
// and unsqueeze_2.
// NOTE(review): nothing visible here copies Transpose_after's friendly name or runtime info onto
// unsqueeze_1 — confirm that downstream output naming does not depend on it.
185 const auto &gru_target_inputs = gru_sequence->output(0).get_target_inputs();
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
186 if (gru_target_inputs.empty())
188 auto transpose_after = gru_target_inputs.begin()->get_node()->shared_from_this();
189 ngraph::replace_node(transpose_after, unsqueeze_1);
190 ngraph::replace_node(gru_sequence, {gru_sequence_ie->output(0), unsqueeze_2->output(0)});
// Register the callback for the GRUSequence pattern under the given matcher name.
195 auto m = std::make_shared<ngraph::pattern::Matcher>(gru_sequence_ngraph, "ConvertGRUSequenceToGRUSequenceIE");
196 this->register_matcher(m, callback);
// Matcher pass that rewrites an opset5::RNNSequence into the legacy op::RNNSequenceIE:
// the num_directions axis of the inputs is squeezed, W and R are concatenated, and each IE
// output is unsqueezed back at axis 1 before the original node is replaced.
// NOTE(review): this excerpt omits several source lines (the embedded numbering skips e.g.
// 203 -> 208, 209 -> 212, 214 -> 216, 217 -> 220, 230 -> 232, 233 -> 236, 241 -> 244, 248 -> 250,
// 251 -> 254 -> 256, 257 -> 259 and 261 -> 266); statements described below as "not visible"
// live in those gaps.
199 ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
// Match any opset5::RNNSequence node.
200 auto rnn_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::opset5::RNNSequence>();
202 ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
203 auto rnn_sequence = std::dynamic_pointer_cast<ngraph::opset5::RNNSequence>(m.get_match_root());
208 // Bidirectional cases are not supported
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
209 if (rnn_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
212 // Detect pattern: Transpose_before -> Seq -> Transpose_after
213 auto seq_axis = get_seq_axis(rnn_sequence);
// By default the IE op consumes the same data input as the original node.
214 ngraph::Output<ngraph::Node> in_0 = rnn_sequence->input(0).get_source_output();
216 // Redirect in_0 to the input of Transpose_before, bypassing that Transpose.
// NOTE(review): in the visible text this reassignment is unconditional; the gap 214 -> 216
// suggests a guard (presumably on seq_axis == 0) that is not shown. Without such a guard any
// producer of input 0 would be skipped, so confirm the guard exists.
217 in_0 = rnn_sequence->get_input_source_output(0).get_node_shared_ptr()->get_input_source_output(0);
// W (input 3) and R (input 4); presumably following the opset5 RNNSequence input order
// X, initial_hidden_state, sequence_lengths, W, R, B — confirm against the spec.
220 auto W = rnn_sequence->input_value(3);
221 auto R = rnn_sequence->input_value(4);
223 // for forward/reverse cases we can squeeze num_direction dimension
224 auto axis_1 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
225 auto in_1 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->input_value(1), axis_1);
// Concatenate W and R along axis 2, then squeeze axis 0 from the concatenation and from input 5
// (presumably B — confirm).
226 auto concat = std::make_shared<ngraph::opset5::Concat>(ngraph::OutputVector{W, R}, 2);
227 auto axis_2 = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
228 auto in_3 = std::make_shared<ngraph::opset5::Squeeze>(concat->output(0), axis_2);
229 auto in_4 = std::make_shared<ngraph::opset5::Squeeze>(rnn_sequence->input_value(5), axis_2);
// Build the legacy op from the squeezed inputs, input 2 and the original node's attributes
// (hidden size, direction, activations, alpha/beta, clip). Some arguments are not visible here.
230 auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
232 in_1, // initial_hidden_state
233 rnn_sequence->input_value(2),
236 rnn_sequence->get_hidden_size(),
237 rnn_sequence->get_direction(),
238 rnn_sequence->get_activations(),
239 rnn_sequence->get_activations_alpha(),
240 rnn_sequence->get_activations_beta(),
241 rnn_sequence->get_clip(),
// Re-insert the num_directions axis (axis 1) on both IE outputs.
244 auto unsqueeze_axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
245 auto unsqueeze_1 = std::make_shared<ngraph::opset5::Unsqueeze>(rnn_sequence_ie->output(0), unsqueeze_axis);
246 auto unsqueeze_2 = std::make_shared<ngraph::opset5::Unsqueeze>(rnn_sequence_ie->output(1), unsqueeze_axis);
// Propagate runtime info from the original node to the nodes created above.
// NOTE(review): the continuation of this call (source line 249) is not visible in this excerpt.
248 ngraph::copy_runtime_info(rnn_sequence, {concat, rnn_sequence_ie, in_1, in_3, in_4, unsqueeze_1,
250 unsqueeze_1->set_friendly_name(rnn_sequence->get_friendly_name()+".0");
251 unsqueeze_2->set_friendly_name(rnn_sequence->get_friendly_name()+".1");
// Rewire every consumer of the original node's outputs to the corresponding Unsqueeze.
// NOTE(review): this call and the statements after it both replace rnn_sequence; the numbering
// gaps (251 -> 254 -> 256) suggest they are alternative branches (presumably selected by seq_axis)
// whose condition is not visible here — confirm.
254 ngraph::replace_node(rnn_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
// Alternative rewiring: the first consumer of output 0 (Transpose_after in the detected pattern)
// is replaced by unsqueeze_1, and the original node's outputs are redirected to the IE output 0
// and unsqueeze_2.
// NOTE(review): nothing visible here copies Transpose_after's friendly name or runtime info onto
// unsqueeze_1 — confirm that downstream output naming does not depend on it.
256 const auto &rnn_target_inputs = rnn_sequence->output(0).get_target_inputs();
// NOTE(review): the statement guarded by this check is not visible in this excerpt.
257 if (rnn_target_inputs.empty())
259 auto transpose_after = rnn_target_inputs.begin()->get_node()->shared_from_this();
260 ngraph::replace_node(transpose_after, unsqueeze_1);
261 ngraph::replace_node(rnn_sequence, {rnn_sequence_ie->output(0), unsqueeze_2->output(0)});
// Register the callback for the RNNSequence pattern under the given matcher name.
266 auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_sequence_ngraph, "ConvertRNNSequenceToRNNSequenceIE");
267 this->register_matcher(m, callback);