Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / frontend / onnx_import / src / op / lstm.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <cstddef>
18 #include <cstdint>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <vector>
23
24 #include "lstm.hpp"
25 #include "ngraph/builder/reshape.hpp"
26 #include "ngraph/builder/split.hpp"
27 #include "ngraph/enum_names.hpp"
28 #include "ngraph/op/add.hpp"
29 #include "ngraph/op/constant.hpp"
30 #include "ngraph/op/lstm_sequence.hpp"
31 #include "ngraph/op/util/attr_types.hpp"
32 #include "ngraph/shape.hpp"
33 #include "ngraph/type/element_type.hpp"
34 #include "onnx_import/core/null_node.hpp"
35 #include "onnx_import/default_opset.hpp"
36 #include "onnx_import/exceptions.hpp"
37 #include "onnx_import/op/lstm.hpp"
38
39 namespace ngraph
40 {
41     namespace onnx_import
42     {
43         namespace op
44         {
45             namespace
46             {
47                 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
48
49                 enum class LSTMInput
50                 {
51                     LSTM_INPUT_X,
52                     LSTM_INPUT_W,
53                     LSTM_INPUT_R,
54                     LSTM_INPUT_B,
55                     LSTM_INPUT_SEQ_LENGTHS,
56                     LSTM_INPUT_INIT_H,
57                     LSTM_INPUT_INIT_C,
58                     LSTM_INPUT_P
59                 };
60
61                 struct LSTMNgInputMap
62                 {
63                     using container_type = std::map<LSTMInput, Output<ngraph::Node>>;
64                     using iterator = typename container_type::iterator;
65
66                     explicit LSTMNgInputMap(const Node& node)
67                     {
68                         const auto& ng_inputs = node.get_ng_inputs();
69                         // We have input, output, forget and cell gates
70                         constexpr std::size_t gates_count{4};
71                         // Peepholes add additional connections to input, output and forget gates.
72                         constexpr std::size_t peepholes_count{3};
73
74                         // ----- Mandatory inputs ------
75                         // Packed input sequences. Shape: [seq_length, batch_size, input_size]
76                         m_map[LSTMInput::LSTM_INPUT_X] =
77                             builder::opset1::reorder_axes(ng_inputs.at(0), {1, 0, 2});
78                         // Weight tensor for the gates.
79                         // Shape: [num_directions, 4*hidden_size, input_size]
80                         m_map[LSTMInput::LSTM_INPUT_W] = ng_inputs.at(1);
81                         // The recurrence weight tensor.
82                         // Shape: [num_directions, 4*hidden_size, hidden_size]
83                         m_map[LSTMInput::LSTM_INPUT_R] = ng_inputs.at(2);
84
85                         const std::size_t hidden_size =
86                             m_map[LSTMInput::LSTM_INPUT_R].get_shape().back();
87                         const std::size_t batch_size =
88                             m_map[LSTMInput::LSTM_INPUT_X].get_shape().at(0);
89                         const std::size_t num_directions =
90                             m_map[LSTMInput::LSTM_INPUT_W].get_shape().front();
91
92                         // ------ Optional inputs ------
93                         // The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
94                         if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
95                         {
96                             auto bias = ng_inputs.at(3);
97                             auto split_bias = builder::opset1::split(bias, 2, 1);
98                             NGRAPH_SUPPRESS_DEPRECATED_START
99                             m_map[LSTMInput::LSTM_INPUT_B] = split_bias.at(0) + split_bias.at(1);
100                             NGRAPH_SUPPRESS_DEPRECATED_END
101                         }
102                         else
103                         {
104                             m_map[LSTMInput::LSTM_INPUT_B] = default_opset::Constant::create(
105                                 element::f32,
106                                 Shape{num_directions, gates_count * hidden_size},
107                                 std::vector<float>(num_directions * gates_count * hidden_size,
108                                                    0.f));
109                         }
110                         // The lengths of the sequences in a batch. Shape [batch_size]
111                         if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
112                         {
113                             m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
114                         }
115                         else
116                         {
117                             m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] =
118                                 default_opset::Constant::create(
119                                     element::i32,
120                                     Shape{batch_size},
121                                     std::vector<std::int32_t>(
122                                         batch_size,
123                                         m_map[LSTMInput::LSTM_INPUT_X].get_shape().at(1)));
124                         }
125                         // The initial value of the hidden.
126                         // Shape [num_directions, batch_size, hidden_size]
127                         if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
128                         {
129                             m_map[LSTMInput::LSTM_INPUT_INIT_H] =
130                                 builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
131                         }
132                         else
133                         {
134                             m_map[LSTMInput::LSTM_INPUT_INIT_H] = default_opset::Constant::create(
135                                 element::f32,
136                                 Shape{batch_size, num_directions, hidden_size},
137                                 std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
138                         }
139                         // The initial value of the cell.
140                         // Shape [num_directions, batch_size, hidden_size]
141                         if (ng_inputs.size() > 6 && !ngraph::op::is_null(ng_inputs.at(6)))
142                         {
143                             m_map[LSTMInput::LSTM_INPUT_INIT_C] =
144                                 builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
145                         }
146                         else
147                         {
148                             m_map[LSTMInput::LSTM_INPUT_INIT_C] = default_opset::Constant::create(
149                                 element::f32,
150                                 Shape{batch_size, num_directions, hidden_size},
151                                 std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
152                         }
153                         // The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
154                         if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7)))
155                         {
156                             m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
157                         }
158                         else
159                         {
160                             m_map[LSTMInput::LSTM_INPUT_P] = default_opset::Constant::create(
161                                 element::f32,
162                                 Shape{num_directions, peepholes_count * hidden_size},
163                                 std::vector<float>(num_directions * peepholes_count * hidden_size,
164                                                    0.f));
165                         }
166                     }
167
168                     Output<ngraph::Node>& at(const LSTMInput& key) { return m_map.at(key); }
169                     container_type m_map;
170                 };
171
172                 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
173                 struct LSTMAttributes
174                 {
175                     explicit LSTMAttributes(const Node& node)
176                         : m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
177                         , m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
178                         , m_activations{node.get_attribute_value<std::vector<std::string>>(
179                               "activations", {"sigmoid", "tanh", "tanh"})}
180                         // Default values for activation functions are same as for corresponding
181                         // ONNX operator.
182                         , m_activation_alpha{node.get_attribute_value<std::vector<float>>(
183                               "activation_alpha", std::vector<float>{})}
184                         , m_activation_beta{node.get_attribute_value<std::vector<float>>(
185                               "activation_beta", std::vector<float>{})}
186                         , m_input_forget{static_cast<bool>(
187                               node.get_attribute_value<std::int64_t>("input_forget", 0))}
188                     {
189                         m_clip_threshold = std::abs(m_clip_threshold);
190                         std::string direction = ngraph::to_lower(
191                             node.get_attribute_value<std::string>("direction", "forward"));
192
193                         m_direction =
194                             ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
195                     }
196
197                     ngraph::op::RecurrentSequenceDirection m_direction;
198                     std::int64_t m_hidden_size;
199                     float m_clip_threshold;
200                     std::vector<std::string> m_activations;
201                     std::vector<float> m_activation_alpha;
202                     std::vector<float> m_activation_beta;
203                     bool m_input_forget;
204                 };
205
206             } // anonymous namespace
207
208             namespace set_1
209             {
210                 OutputVector lstm(const Node& node)
211                 {
212                     LSTMNgInputMap input_map{node};
213                     LSTMAttributes attributes{node};
214
215                     auto lstmSequence = std::make_shared<default_opset::LSTMSequence>(
216                         input_map.at(LSTMInput::LSTM_INPUT_X),
217                         input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
218                         input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
219                         input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
220                         input_map.at(LSTMInput::LSTM_INPUT_W),
221                         input_map.at(LSTMInput::LSTM_INPUT_R),
222                         input_map.at(LSTMInput::LSTM_INPUT_B),
223                         input_map.at(LSTMInput::LSTM_INPUT_P),
224                         attributes.m_hidden_size,
225                         attributes.m_direction,
226                         ngraph::op::LSTMWeightsFormat::IOFC,
227                         attributes.m_activation_alpha,
228                         attributes.m_activation_beta,
229                         attributes.m_activations,
230                         attributes.m_clip_threshold,
231                         attributes.m_input_forget);
232
233                     const auto Y = lstmSequence->output(0);
234                     const auto Y_h = lstmSequence->output(1);
235                     const auto Y_c = lstmSequence->output(2);
236
237                     return {builder::opset1::reorder_axes(Y, {2, 1, 0, 3}),
238                             builder::opset1::reorder_axes(Y_h, {1, 0, 2}),
239                             builder::opset1::reorder_axes(Y_c, {1, 0, 2})};
240                 }
241             } // namespace set_1
242
243         } // namespace op
244
245     } // namespace onnx_import
246
247 } // namespace ngraph