Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_rnn.cpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_rnn.h"
6 #include "mkldnn_extension_utils.h"
7 #include "desc_iterator.hpp"
8 #include <ie_layers_prv.h>
9
10 #include <string>
11 #include <utility>
12
13 using namespace mkldnn;
14 using namespace InferenceEngine;
15
16 namespace MKLDNNPlugin {
17
18 template <typename T, typename P>
19 inline bool one_of(T val, P item) { return val == item; }
20 template <typename T, typename P, typename... Args>
21 inline bool one_of(T val, P item, Args... item_others) {
22     return val == item || one_of(val, item_others...);
23 }
24
25 rnn_direction ie2mkl(RNNLayer::Direction &direction) {
26     return direction == RNNLayer::RNN_FWD ? unidirectional_left2right
27          : direction == RNNLayer::RNN_BWD ? unidirectional_right2left
28          : direction == RNNLayer::RNN_BDR ? bidirectional_concat
29                                           : unidirectional;
30 }
31
32 MKLDNNRNN::MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
33     is_cell = layer->type == "LSTMCell";
34 }
35
36 bool MKLDNNRNN::created() const {
37     return getType() == (is_cell ? LSTMCell : RNN);
38 }
39
40 void MKLDNNRNN::getSupportedDescriptors() {
41     if (is_cell)
42         fillCellDesc();
43     else
44         fillSeqDesc();
45 }
46
47 void MKLDNNRNN::fillCellDesc() {
48     if (!descs.empty()) return;
49     auto cellLayer = std::dynamic_pointer_cast<InferenceEngine::LSTMCell>(getCnnLayer());
50
51     if (!cellLayer)
52         THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";
53
54     auto &ins = cellLayer->insData;
55     auto &outs = cellLayer->outData;
56
57     if (ins.size() != 3)
58         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
59     if (outs.size() != 2)
60         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
61
62     auto in_data_dims = getParentEdgeAt(0)->getDims();
63     auto in_h_state_dims = getParentEdgeAt(1)->getDims();
64     auto in_c_state_dims = getParentEdgeAt(2)->getDims();
65
66     auto out_h_state_dims = getChildEdgeAt(0)->getDims();
67     auto out_c_state_dims = getChildEdgeAt(1)->getDims();
68
69     if (in_data_dims.ndims() != 2
70         || in_h_state_dims.ndims() != 2
71         || in_c_state_dims.ndims() != 2
72         || out_h_state_dims.ndims() != 2
73         || out_c_state_dims.ndims() != 2)
74         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
75
76     T = 1;
77     N  = in_data_dims[0];
78     DC = in_data_dims[1];
79     SC = in_h_state_dims[1];
80
81     // Expected shapes
82     MKLDNNDims D_shape {N, DC}, S_shape {N, SC};
83
84     if (in_data_dims != D_shape
85         || in_h_state_dims != S_shape
86         || in_c_state_dims != S_shape
87         || out_h_state_dims != S_shape
88         || out_c_state_dims != S_shape)
89         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
90
91     auto blobs = cellLayer->blobs;
92     Blob::Ptr weights, bias;
93     if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
94     if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
95
96     if (!weights)
97         THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
98
99     if (weights->size() != G*SC*(SC+DC))
100         THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
101
102     if (bias && bias->size() != G*SC)
103         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
104
105     // Shapes and Attributes are correct. Can start internal stuff initialization.
106
107     in_state_d  = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
108     out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
109
110     in_data_d  = {{T, N, DC}, memory::f32, memory::tnc};;
111     out_data_d = {{T, N, SC}, memory::f32, memory::tnc};;
112
113     w_data_d   = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
114     w_state_d  = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
115
116     if (bias)
117         w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};
118
119     std::vector<TensorDesc> in_candidate;
120     in_candidate.emplace_back(MKLDNNMemoryDesc {D_shape, memory::f32, memory::nc});
121     in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
122     in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
123
124     std::vector<TensorDesc> out_candidate;
125     out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
126     out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
127
128     createDescriptor(in_candidate, out_candidate);
129 }
130
131 void MKLDNNRNN::fillSeqDesc() {
132     if (!descs.empty()) return;
133     auto rnnLayer = std::dynamic_pointer_cast<RNNLayer>(getCnnLayer());
134
135     if (!rnnLayer)
136         THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";
137
138     if (!one_of(rnnLayer->cellType, "LSTM"))
139         THROW_IE_EXCEPTION << "RNN layer supports only LSTM like cell";
140
141     if (!one_of(rnnLayer->axis, 0, 1))
142         THROW_IE_EXCEPTION << "RNN layer supports only sequence axis 0 or 1";
143     nativeOrder = rnnLayer->axis == 0;
144
145     if (!one_of(rnnLayer->direction, RNNLayer::RNN_FWD, RNNLayer::RNN_BWD))
146         THROW_IE_EXCEPTION << "RNN layer supports only unidirectional RNN layer";
147     direction = ie2mkl(rnnLayer->direction);
148
149     auto &ins = rnnLayer->insData;
150     auto &outs = rnnLayer->outData;
151
152     if (!one_of(ins.size(), 3, 1))
153         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
154     if (!one_of(outs.size(), 3, 1))
155         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
156
157     auto in_data_dims = getParentEdgeAt(0)->getDims();
158     auto out_data_dims = getChildEdgeAt(0)->getDims();
159
160     if (in_data_dims.ndims() != 3 || out_data_dims.ndims() != 3)
161         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
162
163     if (!nativeOrder) {
164         std::swap(in_data_dims[0], in_data_dims[1]);
165         std::swap(out_data_dims[0], out_data_dims[1]);
166     }
167
168     T = in_data_dims[0];
169     N = in_data_dims[1];
170     DC = in_data_dims[2];
171     SC = out_data_dims[2];
172
173     MKLDNNDims ID_shape {T, N, DC}, OD_shape {T, N, SC}, S_shape {N, SC};
174
175     if (out_data_dims != OD_shape)
176         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
177
178     if (ins.size() == 3) {
179         auto state_dims1 = getParentEdgeAt(1)->getDims();
180         auto stats_dims2 = getParentEdgeAt(2)->getDims();
181
182         if (state_dims1 != S_shape || stats_dims2 != S_shape)
183             THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
184
185         in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
186     }
187
188     if (outs.size() == 3) {
189         auto state_dims1 = getChildEdgeAt(1)->getDims();
190         auto stats_dims2 = getChildEdgeAt(2)->getDims();
191
192         if (state_dims1 != S_shape || stats_dims2 != S_shape)
193             THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
194
195         out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
196     }
197
198     auto blobs = rnnLayer->blobs;
199     Blob::Ptr weights, bias;
200     if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
201     if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
202
203     if (!weights)
204         THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
205
206     if (weights->size() != G*SC*(SC+DC))
207         THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
208
209     w_data_d  = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
210     w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
211
212     if (bias && bias->size() != G*SC)
213         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
214
215     if (bias)
216         w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};
217
218     // Try to create descriptor and corresponding configuration
219     in_data_d = {in_data_dims, memory::f32, memory::tnc};
220     out_data_d = {out_data_dims, memory::f32, memory::tnc};
221
222     std::vector<TensorDesc> in_candidate;
223     if (nativeOrder)
224         in_candidate.push_back(in_data_d);
225     else
226         in_candidate.push_back(MKLDNNMemoryDesc{{N, T, DC}, memory::f32, memory::ntc});
227
228     if (ins.size() == 3) {
229         in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
230         in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
231     }
232
233     std::vector<TensorDesc> out_candidate;
234     if (nativeOrder)
235         out_candidate.push_back(out_data_d);
236     else
237         out_candidate.push_back(MKLDNNMemoryDesc{{N, T, SC}, memory::f32, memory::ntc});
238
239     if (outs.size() == 3) {
240         out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
241         out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
242     }
243
244     createDescriptor(in_candidate, out_candidate);
245 }
246
247 void MKLDNNRNN::createDescriptor(const std::vector<TensorDesc> &inputDesc,
248                                  const std::vector<TensorDesc> &outputDesc) {
249     MKLDNNDescriptor desc(std::shared_ptr<rnn_forward::desc>(
250             new rnn_forward::desc(forward_scoring,
251                     {algorithm::vanilla_lstm, algorithm::eltwise_tanh },
252                     direction,
253                     /* In Data       */ in_data_d,
254                     /* In State      */ in_state_d,
255                     /* Weights data  */ w_data_d,
256                     /* Weights state */ w_state_d,
257                     /* Bias          */ w_bias_d,
258                     /* Out Data      */ out_data_d,
259                     /* Out State     */ out_state_d)));
260     descs.push_back(desc);
261
262     // Fill supported config
263     InferenceEngine::LayerConfig config;
264     config.dynBatchSupport = false;
265     for (size_t i = 0; i < inputDesc.size(); i++) {
266         InferenceEngine::DataConfig dataConfig;
267         dataConfig.inPlace = -1;
268         dataConfig.constant = false;
269         dataConfig.desc = inputDesc[i];
270         config.inConfs.push_back(dataConfig);
271     }
272
273     for (size_t i = 0; i < outputDesc.size(); i++) {
274         InferenceEngine::DataConfig dataConfig;
275         dataConfig.inPlace = -1;
276         dataConfig.constant = false;
277         dataConfig.desc = outputDesc[i];
278         config.outConfs.push_back(dataConfig);
279     }
280
281     supportedPrimitiveDescriptors.push_back({config, ref_any});
282 }
283
284 void MKLDNNRNN::createPrimitive() {
285     if (prim) return;
286
287     std::shared_ptr<rnn_forward::desc> d = descs[0];
288     rnn_forward::primitive_desc pd(*d, getEngine());
289
290     auto src_data_mem = getParentEdgeAt(0)->getMemoryPtr();
291     auto dst_data_mem = getChildEdgeAt(0)->getMemoryPtr();
292
293     // create weight blobs (data and state part)
294     auto w_data_mem = std::make_shared<MKLDNNMemory>(getEngine());
295     w_data_mem->Create(w_data_d);
296     internalBlobMemory.push_back(w_data_mem);
297
298     auto w_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
299     w_state_mem->Create(w_state_d);
300     internalBlobMemory.push_back(w_state_mem);
301
302     auto w_bias_mem = std::make_shared<MKLDNNMemory>(getEngine());
303     w_bias_mem->Create(w_bias_d);
304     internalBlobMemory.push_back(w_bias_mem);
305
306     {
307         /* Copy Weight data
308          *
309          * IE format:
310          *   W - [gates, out_state_size, in_data_size + in_state_size]
311          *   B - [gates, out_state_size]
312          *
313          * MKLDNN format:
314          *   W - [1, 1, in_date_size,  gates, out_state_size]
315          *   R - [1, 1, in_state_size, gates, out_state_size]
316          *   B - [gates, out_state_size]
317          *
318          *   Gate order
319          *   Caffe - IFOC, ONNX   - IOFC
320          *   IE    - FICO, mkldnn - IFCO
321          */
322         // FICO -> IFCO
323         const int gate_map[] = {1, 0, 2, 3};
324
325         auto ie_w_ptr = getCnnLayer()->blobs["weights"]->buffer().as<const float*>();
326         auto w_ptr = static_cast<float*>(w_data_mem->GetData());
327         auto r_ptr = static_cast<float*>(w_state_mem->GetData());
328         const int step = SC * G;
329
330         for (int g = 0; g < G; g++) {
331             for (int out_i = 0; out_i < SC; out_i++) {
332                 float *l_w_ptr = w_ptr + gate_map[g]*SC + out_i;
333                 float *l_r_ptr = r_ptr + gate_map[g]*SC+ out_i;
334                 for (int in_i = 0; in_i < DC; in_i++) {
335                     *l_w_ptr = *ie_w_ptr;
336                     ie_w_ptr++;
337                     l_w_ptr += step;
338                 }
339
340                 for (int in_i = 0; in_i < SC; in_i++) {
341                     *l_r_ptr = *ie_w_ptr;
342                     ie_w_ptr++;
343                     l_r_ptr += step;
344                 }
345             }
346         }
347
348         if (w_bias_d) {
349             auto ie_b_ptr = getCnnLayer()->blobs["biases"]->buffer().as<const float*>();
350             auto b_ptr = static_cast<float*>(w_bias_mem->GetData());
351             for (int g = 0; g < G; g++) {
352                 float *l_b_ptr = b_ptr + gate_map[g]*SC;
353                 for (int out_i = 0; out_i < SC; out_i++) {
354                     *l_b_ptr = *ie_b_ptr;
355                     ie_b_ptr++;
356                     l_b_ptr++;
357                 }
358             }
359         }
360     }
361
362     auto src_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
363     src_state_mem->Create(in_state_d);
364     internalBlobMemory.push_back(src_state_mem);
365     if (in_state_d) {
366         /* create copy/concat primitive */
367         auto src_stat_1 = getParentEdgeAt(1)->getMemory().GetPrimitive();
368         auto src_stat_2 = getParentEdgeAt(2)->getMemory().GetPrimitive();
369
370         auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
371         low_half_state_mem->Create(
372                 src_stat_1.get_primitive_desc().desc(),
373                 src_state_mem->GetPrimitive().get_data_handle());
374         internalBlobMemory.push_back(low_half_state_mem);
375
376         auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
377         high_half_state_mem->Create(
378                 src_stat_2.get_primitive_desc().desc(),
379                 static_cast<uint8_t*>(src_state_mem->GetPrimitive().get_data_handle()) +
380                 src_stat_1.get_primitive_desc().get_size());
381         internalBlobMemory.push_back(high_half_state_mem);
382
383         exec_before.emplace_back(src_stat_1, low_half_state_mem->GetPrimitive());
384         exec_before.emplace_back(src_stat_2, high_half_state_mem->GetPrimitive());
385     }
386
387     auto dst_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
388     dst_state_mem->Create(out_state_d);
389     internalBlobMemory.push_back(dst_state_mem);
390     if (out_state_d) {
391         int idx_H = is_cell ? 0 : 1;
392         int idx_C = is_cell ? 1 : 2;
393         /* create copy/split primitive */
394         auto dst_stat_1 = getChildEdgeAt(idx_H)->getMemory().GetPrimitive();
395         auto dst_stat_2 = getChildEdgeAt(idx_C)->getMemory().GetPrimitive();
396
397         auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
398         low_half_state_mem->Create(
399                 dst_stat_1.get_primitive_desc().desc(),
400                 dst_state_mem->GetPrimitive().get_data_handle());
401         internalBlobMemory.push_back(low_half_state_mem);
402
403         auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
404         high_half_state_mem->Create(
405                 dst_stat_2.get_primitive_desc().desc(),
406                 static_cast<uint8_t*>(dst_state_mem->GetPrimitive().get_data_handle()) +
407                         dst_stat_1.get_primitive_desc().get_size());
408         internalBlobMemory.push_back(high_half_state_mem);
409
410
411         if (!is_cell) exec_after.emplace_back(low_half_state_mem->GetPrimitive(),  dst_stat_1);
412         exec_after.emplace_back(high_half_state_mem->GetPrimitive(), dst_stat_2);
413     }
414
415     auto workspace_mem = std::make_shared<MKLDNNMemory>(getEngine());
416     workspace_mem->Create({}, memory::f32, memory::format_undef, nullptr);  // stub, not in use
417     internalBlobMemory.push_back(workspace_mem);
418
419     auto p = new rnn_forward(pd,
420             /* In Data       */ src_data_mem ->GetPrimitive(),
421             /* In State      */ src_state_mem->GetPrimitive(),
422             /* Weights data  */ w_data_mem   ->GetPrimitive(),
423             /* Weights state */ w_state_mem  ->GetPrimitive(),
424             /* Bias          */ w_bias_mem   ->GetPrimitive(),
425             /* Out Data      */ dst_data_mem ->GetPrimitive(),
426             /* Out State     */ dst_state_mem->GetPrimitive(),
427             /* Workspace     */ workspace_mem->GetPrimitive());
428
429     prim.reset(p);
430 }
431
432 void MKLDNNRNN::execute(mkldnn::stream strm) {
433     if (!exec_before.empty())
434         strm.submit({exec_before.begin(), exec_before.end()});
435
436     if (prim)
437         strm.submit({*prim});
438
439     if (!exec_after.empty())
440         strm.submit({exec_after.begin(), exec_after.end()});
441 }
442
443 }  // namespace MKLDNNPlugin