Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] ngraph/frontend/onnx_import/src/utils/recurrent.cpp
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <cstdint>
#include <cstdlib>
#include <vector>

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "onnx_import/core/null_node.hpp"
#include "onnx_import/default_opset.hpp"
#include "recurrent.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace recurrent
        {
            OpInputMap::OpInputMap(const onnx_import::Node& node, std::size_t gates_count)
            {
                const auto& ng_inputs = node.get_ng_inputs();

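                // ONNX recurrent operators (RNN/GRU/LSTM) provide their inputs in the order
                // X, W, R, [B], [sequence_lens], [initial_h], ...; only the first three are
                // mandatory, optional inputs may be passed in as null nodes.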
                m_map[OpInput::X] = ng_inputs.at(0);
                m_map[OpInput::W] = ng_inputs.at(1);
                m_map[OpInput::R] = ng_inputs.at(2);

                const auto el_type = ng_inputs.at(0).get_element_type();

                const auto x_pshape = m_map[OpInput::X].get_partial_shape();
                const auto w_pshape = m_map[OpInput::W].get_partial_shape();
                const auto r_pshape = m_map[OpInput::R].get_partial_shape();
                NGRAPH_CHECK(x_pshape.rank().is_static() && x_pshape[0].is_static() &&
                                 x_pshape[1].is_static(),
                             "RecurrentSequence input X must have static \"seq_length\" and "
                             "\"batch_size\" dimensions.");
                NGRAPH_CHECK(w_pshape.rank().is_static() && w_pshape[0].is_static(),
                             "RecurrentSequence input W must have static \"num_directions\" "
                             "(outermost) dimension.");
                NGRAPH_CHECK(r_pshape.rank().is_static() && r_pshape[2].is_static(),
                             "RecurrentSequence input R must have static \"hidden_size\" "
                             "(innermost) dimension.");

                const std::size_t hidden_size = m_map[OpInput::R].get_shape().back();
                const std::size_t batch_size = m_map[OpInput::X].get_shape().at(1);
                const std::size_t num_directions = m_map[OpInput::W].get_shape().front();

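                // ONNX B is the concatenation of the input and recurrence biases along axis 1,
                // with shape [num_directions, 2 * gates_count * hidden_size]. The two halves
                // are split apart and added together, so downstream cell kernels see a single,
                // already-folded bias tensor.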
                if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
                {
                    auto bias = ng_inputs.at(3);
                    auto split_bias = builder::opset1::split(bias, 2, 1);
                    NGRAPH_SUPPRESS_DEPRECATED_START
                    m_map[OpInput::B] = split_bias.at(0) + split_bias.at(1);
                    NGRAPH_SUPPRESS_DEPRECATED_END
                }
                else
                {
                    m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
                        el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
                }
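                // When sequence_lens is not provided, every batch entry is assumed to run for
                // the full sequence, i.e. seq_length taken from X's outermost dimension.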
                if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
                {
                    m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
                }
                else
                {
                    m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Constant>(
                        element::i32, Shape{batch_size}, m_map[OpInput::X].get_shape().at(0));
                }
                // The initial value of the hidden state.
                if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
                {
                    m_map[OpInput::INIT_H] = ng_inputs.at(5);
                }
                else
                {
                    m_map[OpInput::INIT_H] = std::make_shared<default_opset::Constant>(
                        el_type, Shape{num_directions, batch_size, hidden_size}, 0.f);
                }
            }

            OpInputMap::OpInputMap(container_type&& map)
                : m_map(std::move(map))
            {
            }

            Output<ngraph::Node>& OpInputMap::at(const OpInput& key) { return m_map.at(key); }
            const Output<ngraph::Node>& OpInputMap::at(const OpInput& key) const
            {
                return m_map.at(key);
            }

            // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

            OpAttributes::OpAttributes(const Node& node)
                : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
                , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
                // Recurrent operators which have more activation functions should override
                // this value in the constructor of their respective Attributes struct.
                , m_activations{node.get_attribute_value<std::vector<std::string>>("activations",
                                                                                   {"tanh"})}
                // Default values for the activation functions are the same as for the
                // corresponding ONNX operators.
                , m_activations_alpha{node.get_attribute_value<std::vector<float>>(
                      "activation_alpha", std::vector<float>{})}
                , m_activations_beta{node.get_attribute_value<std::vector<float>>(
                      "activation_beta", std::vector<float>{})}
            {
                m_clip_threshold = std::abs(m_clip_threshold);
                std::string direction =
                    ngraph::to_lower(node.get_attribute_value<std::string>("direction", "forward"));
                m_direction = ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
            }

            // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sequence Computations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

            RecurrentSequence::RecurrentSequence(OpInputMap& args,
                                                 ngraph::op::RecurrentSequenceDirection direction)
                : m_args(args)
                , m_direction(direction)
            {
            }
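
            // A rough usage sketch (illustrative only; the names and the exact kernel
            // signature below are assumptions, not taken from this file): an ONNX operator
            // importer typically builds the pieces above and then drives the decomposition
            // through run_sequence() with a cell-building kernel, e.g.:
            //
            //     OpInputMap input_map{node, /*gates_count=*/1};
            //     OpAttributes attrs{node};
            //     RecurrentSequence sequence{input_map, attrs.m_direction};
            //     auto results = sequence.run_sequence(
            //         [&](const OpInputMap& args,
            //             const Output<ngraph::Node>& in_x,
            //             const Output<ngraph::Node>& prev_h) -> Output<ngraph::Node> {
            //             // Build a single time-step cell from in_x, prev_h and args here
            //             // (e.g. an RNNCell for the plain RNN operator).
            //             return prev_h; // placeholder
            //         });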

            OutputVector RecurrentSequence::run_sequence(const RecurrentCellFunction& kernel)
            {
                OutputVector results;
                if (m_direction == ngraph::op::RecurrentSequenceDirection::FORWARD ||
                    m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE)
                {
                    results = recurrent_sequence_pass(
                        kernel, m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE);
                }
                else if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
                {
                    OutputVector fwd_results{recurrent_sequence_pass(kernel)};
                    OutputVector rev_results{recurrent_sequence_pass(kernel, true)};

                    // Stack together respective outputs from both forward and reverse passes.
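                    // Each single-direction pass returns Y with shape
                    // [seq_length, num_directions, batch_size, hidden_size] and Y_h with shape
                    // [num_directions, batch_size, hidden_size] (num_directions == 1 here), so
                    // the two directions are joined along axis 1 for Y and axis 0 for Y_h.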
                    std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(
                        OutputVector{fwd_results.at(0), rev_results.at(0)}, 1)};
                    results.push_back(Y);

                    std::shared_ptr<ngraph::Node> Y_h{std::make_shared<default_opset::Concat>(
                        OutputVector{fwd_results.at(1), rev_results.at(1)}, 0)};
                    results.push_back(Y_h);
                }
                else
                {
                    throw ngraph_error(
                        "RecurrentSequence: unhandled direction mode during decomposition.");
                }
                return results;
            }

            OutputVector
                RecurrentSequence::recurrent_sequence_pass(const RecurrentCellFunction& kernel,
                                                           bool is_reverse)
            {
                OutputVector h_list;

                // Back up nodes which may be modified below; they are restored before returning.
                Output<ngraph::Node> orig_W = m_args.at(OpInput::W);
                Output<ngraph::Node> orig_R = m_args.at(OpInput::R);
                Output<ngraph::Node> orig_B = m_args.at(OpInput::B);

                Output<ngraph::Node> X = m_args.at(OpInput::X);
                Output<ngraph::Node> H_t = prepare_input(m_args.at(OpInput::INIT_H), is_reverse);
                Output<ngraph::Node> W = prepare_input(m_args.at(OpInput::W), is_reverse);
                Output<ngraph::Node> R = prepare_input(m_args.at(OpInput::R), is_reverse);
                Output<ngraph::Node> B = prepare_input(m_args.at(OpInput::B), is_reverse);
                Output<ngraph::Node> seq_lengths = m_args.at(OpInput::SEQ_LENGTHS);

                m_args.at(OpInput::W) = W;
                m_args.at(OpInput::R) = R;
                m_args.at(OpInput::B) = B;

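                // For a reverse pass, flip each batch entry along the time axis (axis 0; batches
                // are on axis 1) up to its own sequence length, so the loop below can always walk
                // the data front to back.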
                if (is_reverse)
                {
                    X = std::make_shared<default_opset::ReverseSequence>(
                        X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
                }

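                // Slice X along the time axis into seq_length single-step tensors.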
                OutputVector in_seq_steps = builder::opset1::split(X, X.get_shape().at(0));

                for (auto& in_x : in_seq_steps)
                {
                    // Remove the leading singleton (time) dim left over from the split above.
                    in_x = builder::opset1::squeeze(in_x);
                }

                int32_t time_step{1};
                for (const auto& in_x : in_seq_steps)
                {
                    Output<ngraph::Node> H = kernel(m_args, in_x, H_t);

                    // Expand the tensor with a singleton outermost dim, so the per-step outputs
                    // can be concatenated later.
                    // Mask the hidden state tensor in order to handle mixed sequence lengths.
                    // This zeroes out values in batches whose sequence is shorter than the
                    // current time_step.
                    h_list.push_back(
                        get_masked_node(builder::opset1::expand_dims(H), time_step, 1));

                    // Make sure that only the appropriate batches (with respect to their sequence
                    // length) are updated. Batches with shorter sequences preserve their last
                    // valid value.
                    H_t = get_masked_node(H, time_step, 0, H_t);
                    time_step++;
                }

                // Restore the original nodes.
                m_args.at(OpInput::W) = orig_W;
                m_args.at(OpInput::R) = orig_R;
                m_args.at(OpInput::B) = orig_B;

                // Concatenate all intermediate output values of the hidden state.
                // The result has shape [seq_length, batch_size, hidden_size].
                std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(h_list, 0)};

                // Restore the original order of the output data.
                if (is_reverse)
                {
                    Y = std::make_shared<default_opset::ReverseSequence>(
                        Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
                }

                // Expand Y so that it has the expected shape:
                // [seq_length, num_directions, batch_size, hidden_size]
                Y = builder::opset1::expand_dims(Y, 1);

                // Expand H_t so that it has the expected shape:
                // [num_directions, batch_size, hidden_size]
                auto Y_h = builder::opset1::expand_dims(H_t);

                return {Y, Y_h};
            }

            std::shared_ptr<ngraph::Node>
                RecurrentSequence::get_masked_node(const Output<ngraph::Node>& data,
                                                   int32_t time_step,
                                                   size_t batch_axis,
                                                   const Output<ngraph::Node>& default_value) const
            {
                Output<ngraph::Node> mask_value = default_value;
                // Create a zero mask value node.
                if (!mask_value.get_node_shared_ptr())
                {
                    mask_value = std::make_shared<default_opset::Constant>(
                        data.get_element_type(), data.get_shape(), 0.f);
                }

                // Create predicate nodes. The condition is whether the current time step value
                // is greater than the sequence length for the respective batch inputs.
                std::shared_ptr<ngraph::Node> curr_time_step_node =
                    std::make_shared<default_opset::Constant>(
                        element::i32, data.get_shape(), time_step);

                Output<ngraph::Node> batch_seq_length =
                    builder::opset1::legacy_broadcast_for_binary_operation(
                        curr_time_step_node, m_args.at(OpInput::SEQ_LENGTHS), batch_axis);

                // Create a mask node deciding whether or not to mask the batch data.
                std::shared_ptr<ngraph::Node> mask_condition =
                    std::make_shared<default_opset::Greater>(curr_time_step_node, batch_seq_length);

                // Select values depending on mask_condition:
                // Select(<condition>, <true_value>, <false_value>)
                return std::make_shared<default_opset::Select>(mask_condition, mask_value, data);
            }

            std::shared_ptr<ngraph::Node>
                RecurrentSequence::prepare_input(Output<ngraph::Node> node, bool is_reverse) const
            {
                // In bidirectional mode inputs are stacked together, so we must split them.
                Output<ngraph::Node> tmp = node;
                if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
                {
                    tmp = builder::opset1::split(node, 2).at(is_reverse ? 1 : 0);
                }
                // Since we work in forward pass mode, we can squeeze the `num_directions` axis
                // from the input.
                return builder::opset1::squeeze(tmp);
            }

        } // recurrent
    }     // onnx_import
} // ngraph