//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "onnx_import/core/null_node.hpp"
#include "onnx_import/default_opset.hpp"
#include "recurrent.hpp"

namespace ngraph
{
    namespace onnx_import
    {
        namespace recurrent
        {
            OpInputMap::OpInputMap(const onnx_import::Node& node, std::size_t gates_count)
            {
                const auto& ng_inputs = node.get_ng_inputs();

                m_map[OpInput::X] = ng_inputs.at(0);
                m_map[OpInput::W] = ng_inputs.at(1);
                m_map[OpInput::R] = ng_inputs.at(2);
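
                // Expected ONNX input shapes:
                //   X: [seq_length, batch_size, input_size]
                //   W: [num_directions, gates_count * hidden_size, input_size]
                //   R: [num_directions, gates_count * hidden_size, hidden_size]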

                const auto el_type = ng_inputs.at(0).get_element_type();

                const auto x_pshape = m_map[OpInput::X].get_partial_shape();
                const auto w_pshape = m_map[OpInput::W].get_partial_shape();
                const auto r_pshape = m_map[OpInput::R].get_partial_shape();
                NGRAPH_CHECK(x_pshape.rank().is_static() && x_pshape[0].is_static() &&
                                 x_pshape[1].is_static(),
                             "RecurrentSequence input X must have static \"seq_length\" and "
                             "\"batch_size\" dimensions.");
                NGRAPH_CHECK(w_pshape.rank().is_static() && w_pshape[0].is_static(),
                             "RecurrentSequence input W must have static \"num_directions\" "
                             "(outermost) dimension.");
                NGRAPH_CHECK(r_pshape.rank().is_static() && r_pshape[2].is_static(),
                             "RecurrentSequence input R must have static \"hidden_size\" "
                             "(innermost) dimension.");

                const std::size_t hidden_size = m_map[OpInput::R].get_shape().back();
                const std::size_t batch_size = m_map[OpInput::X].get_shape().at(1);
                const std::size_t num_directions = m_map[OpInput::W].get_shape().front();
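
                // The ONNX "B" input concatenates [Wb, Rb] along axis 1, giving shape
                // [num_directions, 2 * gates_count * hidden_size]; splitting it in half
                // and adding the halves folds both biases into one additive bias.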
                if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
                {
                    auto bias = ng_inputs.at(3);
                    auto split_bias = builder::opset1::split(bias, 2, 1);
                    NGRAPH_SUPPRESS_DEPRECATED_START
                    m_map[OpInput::B] = split_bias.at(0) + split_bias.at(1);
                    NGRAPH_SUPPRESS_DEPRECATED_END
                }
                else
                {
                    m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
                        el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
                }
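
                // When sequence lengths are not provided, assume that every batch entry
                // runs for the full "seq_length" taken from the static shape of X.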
                if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
                {
                    m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
                }
                else
                {
                    m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Constant>(
                        element::i32, Shape{batch_size}, m_map[OpInput::X].get_shape().at(0));
                }

                // The initial value of the hidden state (zeros when not provided).
                if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
                {
                    m_map[OpInput::INIT_H] = ng_inputs.at(5);
                }
                else
                {
                    m_map[OpInput::INIT_H] = std::make_shared<default_opset::Constant>(
                        el_type, Shape{num_directions, batch_size, hidden_size}, 0.f);
                }
            }

            OpInputMap::OpInputMap(container_type&& map)
                : m_map(std::move(map))
            {
            }

            Output<ngraph::Node>& OpInputMap::at(const OpInput& key) { return m_map.at(key); }
            const Output<ngraph::Node>& OpInputMap::at(const OpInput& key) const
            {
                return m_map.at(key);
            }

            // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

            OpAttributes::OpAttributes(const Node& node)
                : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
                , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
                // Recurrent operators which have more activation functions should override
                // this value in the constructor of their respective Attributes struct.
                , m_activations{node.get_attribute_value<std::vector<std::string>>("activations",
                                                                                   {"tanh"})}
                // Default values for the activation functions are the same as for the
                // corresponding ONNX operator.
                , m_activations_alpha{node.get_attribute_value<std::vector<float>>(
                      "activation_alpha", std::vector<float>{})}
                , m_activations_beta{node.get_attribute_value<std::vector<float>>(
                      "activation_beta", std::vector<float>{})}
            {
                m_clip_threshold = std::abs(m_clip_threshold);
                std::string direction =
                    ngraph::to_lower(node.get_attribute_value<std::string>("direction", "forward"));
                m_direction = ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
            }
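
            // Note: the ONNX "direction" attribute accepts "forward", "reverse" and
            // "bidirectional"; it is lower-cased above before the enum conversion.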

            // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sequence Computations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

            RecurrentSequence::RecurrentSequence(OpInputMap& args,
                                                 ngraph::op::RecurrentSequenceDirection direction)
                : m_args(args)
                , m_direction(direction)
            {
            }

            OutputVector RecurrentSequence::run_sequence(const RecurrentCellFunction& kernel)
            {
                OutputVector results;
                if (m_direction == ngraph::op::RecurrentSequenceDirection::FORWARD ||
                    m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE)
                {
                    results = recurrent_sequence_pass(
                        kernel, m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE);
                }
                else if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
                {
                    OutputVector fwd_results{recurrent_sequence_pass(kernel)};
                    OutputVector rev_results{recurrent_sequence_pass(kernel, true)};

                    // Stack together the respective outputs of the forward and reverse passes.
                    std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(
                        OutputVector{fwd_results.at(0), rev_results.at(0)}, 1)};
                    results.push_back(Y);

                    std::shared_ptr<ngraph::Node> Y_h{std::make_shared<default_opset::Concat>(
                        OutputVector{fwd_results.at(1), rev_results.at(1)}, 0)};
                    results.push_back(Y_h);
                }
                else
                {
                    throw ngraph_error(
                        "RecurrentSequence: unhandled direction mode during decomposition.");
                }
                return results;
            }

            OutputVector
                RecurrentSequence::recurrent_sequence_pass(const RecurrentCellFunction& kernel,
                                                           bool is_reverse)
            {
                OutputVector h_list;

                // Back up the nodes which we may modify below.
                Output<ngraph::Node> orig_W = m_args.at(OpInput::W);
                Output<ngraph::Node> orig_R = m_args.at(OpInput::R);
                Output<ngraph::Node> orig_B = m_args.at(OpInput::B);

                Output<ngraph::Node> X = m_args.at(OpInput::X);
                Output<ngraph::Node> H_t = prepare_input(m_args.at(OpInput::INIT_H), is_reverse);
                Output<ngraph::Node> W = prepare_input(m_args.at(OpInput::W), is_reverse);
                Output<ngraph::Node> R = prepare_input(m_args.at(OpInput::R), is_reverse);
                Output<ngraph::Node> B = prepare_input(m_args.at(OpInput::B), is_reverse);
                Output<ngraph::Node> seq_lengths = m_args.at(OpInput::SEQ_LENGTHS);

                m_args.at(OpInput::W) = W;
                m_args.at(OpInput::R) = R;
                m_args.at(OpInput::B) = B;
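
                // The kernel must see single-direction weights, so the prepared W/R/B
                // temporarily replace the originals; they are restored after the loop.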
195 X = std::make_shared<default_opset::ReverseSequence>(
196 X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
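
                // Note: ReverseSequence flips each batch entry only within its own valid
                // sequence length, so any padded tail elements keep their positions.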

                OutputVector in_seq_steps = builder::opset1::split(X, X.get_shape().at(0));

201 for (auto& in_x : in_seq_steps)
203 // remove first empty dim, after above split.
204 in_x = builder::opset1::squeeze(in_x);
207 int32_t time_step{1};
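                // time_step is 1-based so that it can be compared directly against the
                // per-batch sequence lengths inside get_masked_node.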
208 for (const auto& in_x : in_seq_steps)
210 Output<ngraph::Node> H = kernel(m_args, in_x, H_t);
212 // Expand tensors with empty outermost dim, so we can later concatenate
214 // Mask hidden state tensor in order to handle mixed sequence lengths.
215 // This results in zeroing out values in batches with sequence shorter
216 // than current time_step.
218 get_masked_node(builder::opset1::expand_dims(H), time_step, 1));
220 // Here we make sure that only appropriate batches (with respect to its sequence
221 // length) are updated. Those batches which has shorter sequences preserve
223 H_t = get_masked_node(H, time_step, 0, H_t);

                // Restore the original nodes.
                m_args.at(OpInput::W) = orig_W;
                m_args.at(OpInput::R) = orig_R;
                m_args.at(OpInput::B) = orig_B;

                // The tensor that concatenates all the intermediate output values of the
                // hidden state. It has shape [seq_length, batch_size, hidden_size].
                std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(h_list, 0)};

                // Get back the original order of the output data.
                if (is_reverse)
                {
                    Y = std::make_shared<default_opset::ReverseSequence>(
                        Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
                }

                // Expand Y so that it has the expected shape:
                // [seq_length, num_directions, batch_size, hidden_size]
                Y = builder::opset1::expand_dims(Y, 1);

                // Expand H_t so that it has the expected shape:
                // [num_directions, batch_size, hidden_size]
                auto Y_h = builder::opset1::expand_dims(H_t);

                return {Y, Y_h};
            }
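
            // Masking sketch (illustrative values): with seq_lengths = {3, 1} and
            // time_step = 2, the second batch entry is already past its sequence end,
            // so its values are replaced with mask_value (zeros, or default_value
            // when one is provided).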
254 std::shared_ptr<ngraph::Node>
255 RecurrentSequence::get_masked_node(const Output<ngraph::Node>& data,
258 const Output<ngraph::Node>& default_value) const
260 Output<ngraph::Node> mask_value = default_value;
261 // Create zero mask value node.
262 if (!mask_value.get_node_shared_ptr())
264 mask_value = std::make_shared<default_opset::Constant>(
265 data.get_element_type(), data.get_shape(), 0.f);
268 // Create predicate nodes. The condition is whether current time step value
269 // is greater than sequence length for respective batch inputs.
270 std::shared_ptr<ngraph::Node> curr_time_step_node =
271 std::make_shared<default_opset::Constant>(
272 element::i32, data.get_shape(), time_step);
274 Output<ngraph::Node> batch_seq_length =
275 builder::opset1::legacy_broadcast_for_binary_operation(
276 curr_time_step_node, m_args.at(OpInput::SEQ_LENGTHS), batch_axis);
278 // Create mask node deciding whether or not to mask batch data.
279 std::shared_ptr<ngraph::Node> mask_condition =
280 std::make_shared<default_opset::Greater>(curr_time_step_node, batch_seq_length);
282 // Select values depnding on mask_condition.
283 // Select(<condition>, <true_value>, <false_value>)
284 return std::make_shared<default_opset::Select>(mask_condition, mask_value, data);
287 std::shared_ptr<ngraph::Node>
288 RecurrentSequence::prepare_input(Output<ngraph::Node> node, bool is_reverse) const
290 // In bidirectional mode inputs are stacked together, so we must split them.
291 Output<ngraph::Node> tmp = node;
292 if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
294 tmp = builder::opset1::split(node, 2).at(is_reverse ? 1 : 0);
296 // Since we work in forward pass mode, we can squeeze `num_directions` axis from
298 return builder::opset1::squeeze(tmp);