1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
25 #include "ngraph/builder/reshape.hpp"
26 #include "ngraph/builder/split.hpp"
27 #include "ngraph/enum_names.hpp"
28 #include "ngraph/op/add.hpp"
29 #include "ngraph/op/constant.hpp"
30 #include "ngraph/op/lstm_sequence.hpp"
31 #include "ngraph/op/util/attr_types.hpp"
32 #include "ngraph/shape.hpp"
33 #include "ngraph/type/element_type.hpp"
34 #include "onnx_import/core/null_node.hpp"
35 #include "onnx_import/default_opset.hpp"
36 #include "onnx_import/exceptions.hpp"
37 #include "onnx_import/op/lstm.hpp"
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Keys for the inputs of an ONNX LSTM node, listed in the order of the ONNX
/// operator's inputs (X, W, R, B, sequence_lens, initial_h, initial_c, P).
enum class LSTMInput
{
    LSTM_INPUT_X,
    LSTM_INPUT_W,
    LSTM_INPUT_R,
    LSTM_INPUT_B,
    LSTM_INPUT_SEQ_LENGTHS,
    LSTM_INPUT_INIT_H,
    LSTM_INPUT_INIT_C,
    LSTM_INPUT_P
};
63 using container_type = std::map<LSTMInput, Output<ngraph::Node>>;
64 using iterator = typename container_type::iterator;
66 explicit LSTMNgInputMap(const Node& node)
68 const auto& ng_inputs = node.get_ng_inputs();
69 // We have input, output, forget and cell gates
70 constexpr std::size_t gates_count{4};
71 // Peepholes add additional connections to input, output and forget gates.
72 constexpr std::size_t peepholes_count{3};
74 // ----- Mandatory inputs ------
75 // Packed input sequences. Shape: [seq_length, batch_size, input_size]
76 m_map[LSTMInput::LSTM_INPUT_X] =
77 builder::opset1::reorder_axes(ng_inputs.at(0), {1, 0, 2});
78 // Weight tensor for the gates.
79 // Shape: [num_directions, 4*hidden_size, input_size]
80 m_map[LSTMInput::LSTM_INPUT_W] = ng_inputs.at(1);
81 // The recurrence weight tensor.
82 // Shape: [num_directions, 4*hidden_size, hidden_size]
83 m_map[LSTMInput::LSTM_INPUT_R] = ng_inputs.at(2);
85 const std::size_t hidden_size =
86 m_map[LSTMInput::LSTM_INPUT_R].get_shape().back();
87 const std::size_t batch_size =
88 m_map[LSTMInput::LSTM_INPUT_X].get_shape().at(0);
89 const std::size_t num_directions =
90 m_map[LSTMInput::LSTM_INPUT_W].get_shape().front();
92 // ------ Optional inputs ------
93 // The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
94 if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
96 auto bias = ng_inputs.at(3);
97 auto split_bias = builder::opset1::split(bias, 2, 1);
98 NGRAPH_SUPPRESS_DEPRECATED_START
99 m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
100 NGRAPH_SUPPRESS_DEPRECATED_END
104 m_map[LSTMInput::LSTM_INPUT_B] = default_opset::Constant::create(
106 Shape{num_directions, gates_count * hidden_size},
107 std::vector<float>(num_directions * gates_count * hidden_size,
110 // The lengths of the sequences in a batch. Shape [batch_size]
111 if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
113 m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
117 m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] =
118 default_opset::Constant::create(
121 std::vector<std::int32_t>(
123 m_map[LSTMInput::LSTM_INPUT_X].get_shape().at(1)));
125 // The initial value of the hidden.
126 // Shape [num_directions, batch_size, hidden_size]
127 if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
129 m_map[LSTMInput::LSTM_INPUT_INIT_H] =
130 builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
134 m_map[LSTMInput::LSTM_INPUT_INIT_H] = default_opset::Constant::create(
136 Shape{batch_size, num_directions, hidden_size},
137 std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
139 // The initial value of the cell.
140 // Shape [num_directions, batch_size, hidden_size]
141 if (ng_inputs.size() > 6 && !ngraph::op::is_null(ng_inputs.at(6)))
143 m_map[LSTMInput::LSTM_INPUT_INIT_C] =
144 builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
148 m_map[LSTMInput::LSTM_INPUT_INIT_C] = default_opset::Constant::create(
150 Shape{batch_size, num_directions, hidden_size},
151 std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
153 // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
154 if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7)))
156 m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
160 m_map[LSTMInput::LSTM_INPUT_P] = default_opset::Constant::create(
162 Shape{num_directions, peepholes_count * hidden_size},
163 std::vector<float>(num_directions * peepholes_count * hidden_size,
168 Output<ngraph::Node>& at(const LSTMInput& key) { return m_map.at(key); }
169 container_type m_map;
172 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
173 struct LSTMAttributes
175 explicit LSTMAttributes(const Node& node)
176 : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
177 , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
178 , m_activations{node.get_attribute_value<std::vector<std::string>>(
179 "activations", {"sigmoid", "tanh", "tanh"})}
180 // Default values for activation functions are same as for corresponding
182 , m_activation_alpha{node.get_attribute_value<std::vector<float>>(
183 "activation_alpha", std::vector<float>{})}
184 , m_activation_beta{node.get_attribute_value<std::vector<float>>(
185 "activation_beta", std::vector<float>{})}
186 , m_input_forget{static_cast<bool>(
187 node.get_attribute_value<std::int64_t>("input_forget", 0))}
189 m_clip_threshold = std::abs(m_clip_threshold);
190 std::string direction = ngraph::to_lower(
191 node.get_attribute_value<std::string>("direction", "forward"));
194 ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
197 ngraph::op::RecurrentSequenceDirection m_direction;
198 std::int64_t m_hidden_size;
199 float m_clip_threshold;
200 std::vector<std::string> m_activations;
201 std::vector<float> m_activation_alpha;
202 std::vector<float> m_activation_beta;
206 } // anonymous namespace
210 OutputVector lstm(const Node& node)
212 LSTMNgInputMap input_map{node};
213 LSTMAttributes attributes{node};
215 auto lstmSequence = std::make_shared<default_opset::LSTMSequence>(
216 input_map.at(LSTMInput::LSTM_INPUT_X),
217 input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
218 input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
219 input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
220 input_map.at(LSTMInput::LSTM_INPUT_W),
221 input_map.at(LSTMInput::LSTM_INPUT_R),
222 input_map.at(LSTMInput::LSTM_INPUT_B),
223 input_map.at(LSTMInput::LSTM_INPUT_P),
224 attributes.m_hidden_size,
225 attributes.m_direction,
226 ngraph::op::LSTMWeightsFormat::IOFC,
227 attributes.m_activation_alpha,
228 attributes.m_activation_beta,
229 attributes.m_activations,
230 attributes.m_clip_threshold,
231 attributes.m_input_forget);
233 const auto Y = lstmSequence->output(0);
234 const auto Y_h = lstmSequence->output(1);
235 const auto Y_c = lstmSequence->output(2);
237 return {builder::opset1::reorder_axes(Y, {2, 1, 0, 3}),
238 builder::opset1::reorder_axes(Y_h, {1, 0, 2}),
239 builder::opset1::reorder_axes(Y_c, {1, 0, 2})};
245 } // namespace onnx_import
247 } // namespace ngraph