Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_rnn.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_rnn.h"
6 #include "mkldnn_extension_utils.h"
7 #include "desc_iterator.hpp"
8
9 #include <string>
10 #include <utility>
11
12 using namespace mkldnn;
13 using namespace InferenceEngine;
14
15 namespace MKLDNNPlugin {
16
17 template <typename T, typename P>
18 inline bool one_of(T val, P item) { return val == item; }
19 template <typename T, typename P, typename... Args>
20 inline bool one_of(T val, P item, Args... item_others) {
21     return val == item || one_of(val, item_others...);
22 }
23
24 using _RNN = RNNSequenceLayer;  // alias
25
26 static rnn_direction ie2mkl(_RNN::Direction &direction) {
27     return direction == _RNN::FWD ? unidirectional_left2right
28          : direction == _RNN::BWD ? unidirectional_right2left
29          : direction == _RNN::BDR ? bidirectional_concat
30          : unidirectional;
31 }
32
33 static algorithm ie2mkl(std::string act_type) {
34     return act_type == "sigmoid" ? eltwise_logistic
35          : act_type == "tanh"    ? eltwise_tanh
36          : act_type == "relu"    ? eltwise_relu
37          : algorithm_undef;
38 }
39
40 static algorithm ie2mkl(RNNCellBase::CellType cell_type) {
41     switch (cell_type) {
42         case RNNCellBase::LSTM: return vanilla_lstm;
43         case RNNCellBase::GRU:  return vanilla_gru;
44         case RNNCellBase::GRU_LBR:  return gru_linear_before_reset;
45         case RNNCellBase::RNN:  return vanilla_rnn;
46         default:
47             THROW_IE_EXCEPTION << "Unsoupported cell type";
48             return algorithm_undef;
49     }
50 }
51
52 MKLDNNRNN::MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
53     is_cell = one_of(layer->type, "LSTMCell", "GRUCell", "RNNCell");
54 }
55
56 bool MKLDNNRNN::created() const {
57     return getType() == (is_cell ? RNNCell : RNNSeq);
58 }
59
60 void MKLDNNRNN::getSupportedDescriptors() {
61     if (is_cell)
62         fillCellDesc();
63     else
64         fillSeqDesc();
65 }
66
67 void MKLDNNRNN::fillCellDesc() {
68     if (!descs.empty()) return;
69     auto cellLayer = std::dynamic_pointer_cast<RNNCellBase>(getCnnLayer());
70
71     if (!cellLayer)
72         THROW_IE_EXCEPTION << "No original layer for RNNCell.";
73
74     algorithm cell_type = ie2mkl(cellLayer->cellType);
75     algorithm cell_act = ie2mkl(cellLayer->activations[0]);  // Works only for RNN with one gate
76
77     cell_desc = {cell_type, cell_act};
78     if (cellLayer->clip != 0.0f)
79         cell_desc.set_clipping(cellLayer->clip);
80
81     auto &ins = cellLayer->insData;
82     auto &outs = cellLayer->outData;
83
84     if (!one_of(ins.size(), 3, 2))
85         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
86     if (!one_of(outs.size(), 2, 1))
87         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
88
89     auto in_data_dims = getParentEdgeAt(0)->getDims();
90     auto in_h_state_dims = getParentEdgeAt(1)->getDims();
91     auto out_h_state_dims = getChildEdgeAt(0)->getDims();
92
93     if (in_data_dims.ndims() != 2 || in_h_state_dims.ndims() != 2)
94         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
95
96     G = cell_desc.get_gates_count();
97     S = cell_desc.get_state_count();
98     T = 1;
99     N  = in_data_dims[0];
100     DC = in_data_dims[1];
101     SC = in_h_state_dims[1];
102
103     Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
104
105     // Expected shapes
106     MKLDNNDims D_shape {N, DC}, S_shape {N, SC};
107
108     if (in_data_dims != D_shape
109         || in_h_state_dims != S_shape
110         || out_h_state_dims != S_shape)
111         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
112
113     if (S == 2) {
114         auto in_c_state_dims = getParentEdgeAt(2)->getDims();
115         auto out_c_state_dims = getChildEdgeAt(1)->getDims();
116
117         if (in_c_state_dims != S_shape
118             || out_c_state_dims != S_shape)
119             THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
120     }
121
122     auto blobs = cellLayer->blobs;
123     Blob::Ptr weights, bias;
124     if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
125     if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
126
127     if (!weights)
128         THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
129
130     if (weights->size() != G*SC*(SC+DC))
131         THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
132
133     if (bias && bias->size() != Gb*SC)
134         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
135
136     // Shapes and Attributes are correct. Can start internal stuff initialization.
137
138     in_state_d  = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
139     out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
140
141     in_data_d  = {{T, N, DC}, memory::f32, memory::tnc};;
142     out_data_d = {{T, N, SC}, memory::f32, memory::tnc};;
143
144     w_data_d   = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
145     w_state_d  = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
146
147     if (bias)
148         w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
149
150     std::vector<TensorDesc> in_candidate, out_candidate;
151     in_candidate.emplace_back(MKLDNNMemoryDesc {D_shape, memory::f32, memory::nc});
152     in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
153     out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
154
155     if (S == 2) {
156         in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
157         out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
158     }
159
160     createDescriptor(in_candidate, out_candidate);
161 }
162
163 void MKLDNNRNN::fillSeqDesc() {
164     if (!descs.empty()) return;
165     auto rnnLayer = std::dynamic_pointer_cast<RNNSequenceLayer>(getCnnLayer());
166
167     if (!rnnLayer)
168         THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNSequenceLayer.";
169
170     if (!one_of(rnnLayer->cellType, _RNN::LSTM, _RNN::GRU, _RNN::GRU_LBR, _RNN::RNN))
171         THROW_IE_EXCEPTION << "RNN layer supports only LSTM/GRU/RNN cell";
172
173     algorithm cell_type = ie2mkl(rnnLayer->cellType);
174     algorithm cell_act = algorithm_undef;
175     if (!rnnLayer->activations.empty())
176         cell_act = ie2mkl(rnnLayer->activations[0]);  // Works only for RNN with one gate
177
178     cell_desc = {cell_type, cell_act};
179
180     if (rnnLayer->clip != 0.0f)
181         cell_desc.set_clipping(rnnLayer->clip);
182
183     if (!one_of(rnnLayer->axis, 0, 1))
184         THROW_IE_EXCEPTION << "RNN layer supports only sequence axis 0 or 1";
185     nativeOrder = rnnLayer->axis == 0;
186
187     if (!one_of(rnnLayer->direction, _RNN::FWD, _RNN::BWD))
188         THROW_IE_EXCEPTION << "RNN layer supports only unidirectional RNN layer";
189     direction = ie2mkl(rnnLayer->direction);
190
191     auto &ins = rnnLayer->insData;
192     auto &outs = rnnLayer->outData;
193
194     if (!one_of(ins.size(), 3, 2, 1))
195         THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
196     if (!one_of(outs.size(), 3, 2, 1))
197         THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
198
199     auto in_data_dims = getParentEdgeAt(0)->getDims();
200     auto out_data_dims = getChildEdgeAt(0)->getDims();
201
202     if (in_data_dims.ndims() != 3 || out_data_dims.ndims() != 3)
203         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
204
205     if (!nativeOrder) {
206         std::swap(in_data_dims[0], in_data_dims[1]);
207         std::swap(out_data_dims[0], out_data_dims[1]);
208     }
209
210     G = cell_desc.get_gates_count();
211     S = cell_desc.get_state_count();
212     T = in_data_dims[0];
213     N = in_data_dims[1];
214     DC = in_data_dims[2];
215     SC = out_data_dims[2];
216
217     Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
218
219     MKLDNNDims ID_shape {T, N, DC}, OD_shape {T, N, SC}, S_shape {N, SC};
220
221     if (out_data_dims != OD_shape)
222         THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
223
224     if (ins.size() > 1) {
225         for (int i = 1; i < ins.size(); i++)
226             if (getParentEdgeAt(i)->getDims() != S_shape)
227                 THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
228
229         in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
230     }
231
232     if (outs.size() > 1) {
233         for (int i = 1; i < outs.size(); i++)
234             if (getChildEdgeAt(i)->getDims() != S_shape)
235                 THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
236
237         out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
238     }
239
240     auto blobs = rnnLayer->blobs;
241     Blob::Ptr weights, bias;
242     if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
243     if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
244
245     if (!weights)
246         THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
247
248     if (weights->size() != G*SC*(SC+DC))
249         THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
250
251     w_data_d  = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
252     w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
253
254     if (bias && bias->size() != Gb*SC)
255         THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
256
257     if (bias)
258         w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
259
260     // Try to create descriptor and corresponding configuration
261     in_data_d = {in_data_dims, memory::f32, memory::tnc};
262     out_data_d = {out_data_dims, memory::f32, memory::tnc};
263
264     std::vector<TensorDesc> in_candidate;
265     if (nativeOrder)
266         in_candidate.push_back(in_data_d);
267     else
268         in_candidate.push_back(MKLDNNMemoryDesc{{N, T, DC}, memory::f32, memory::ntc});
269
270     for (int i = 1; i < ins.size(); i++)
271         in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
272
273     std::vector<TensorDesc> out_candidate;
274     if (nativeOrder)
275         out_candidate.push_back(out_data_d);
276     else
277         out_candidate.push_back(MKLDNNMemoryDesc{{N, T, SC}, memory::f32, memory::ntc});
278
279     for (int i = 1; i < outs.size(); i++)
280         out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
281
282     createDescriptor(in_candidate, out_candidate);
283 }
284
285 void MKLDNNRNN::createDescriptor(const std::vector<TensorDesc> &inputDesc,
286                                  const std::vector<TensorDesc> &outputDesc) {
287     MKLDNNDescriptor desc(std::shared_ptr<rnn_forward::desc>(
288             new rnn_forward::desc(forward_scoring, cell_desc,
289                     direction,
290                     /* In Data       */ in_data_d,
291                     /* In State      */ in_state_d,
292                     /* Weights data  */ w_data_d,
293                     /* Weights state */ w_state_d,
294                     /* Bias          */ w_bias_d,
295                     /* Out Data      */ out_data_d,
296                     /* Out State     */ out_state_d)));
297     descs.push_back(desc);
298
299     // Fill supported config
300     InferenceEngine::LayerConfig config;
301     config.dynBatchSupport = false;
302     for (size_t i = 0; i < inputDesc.size(); i++) {
303         InferenceEngine::DataConfig dataConfig;
304         dataConfig.inPlace = -1;
305         dataConfig.constant = false;
306         dataConfig.desc = inputDesc[i];
307         config.inConfs.push_back(dataConfig);
308     }
309
310     for (size_t i = 0; i < outputDesc.size(); i++) {
311         InferenceEngine::DataConfig dataConfig;
312         dataConfig.inPlace = -1;
313         dataConfig.constant = false;
314         dataConfig.desc = outputDesc[i];
315         config.outConfs.push_back(dataConfig);
316     }
317
318     supportedPrimitiveDescriptors.push_back({config, ref_any});
319 }
320
321 void MKLDNNRNN::createPrimitive() {
322     if (prim) return;
323
324     std::shared_ptr<rnn_forward::desc> d = descs[0];
325     rnn_forward::primitive_desc pd(*d, getEngine());
326
327     auto src_data_mem = getParentEdgeAt(0)->getMemoryPtr();
328     auto dst_data_mem = getChildEdgeAt(0)->getMemoryPtr();
329
330     // create weight blobs (data and state part)
331     auto w_data_mem = std::make_shared<MKLDNNMemory>(getEngine());
332     w_data_mem->Create(w_data_d);
333     internalBlobMemory.push_back(w_data_mem);
334
335     auto w_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
336     w_state_mem->Create(w_state_d);
337     internalBlobMemory.push_back(w_state_mem);
338
339     auto w_bias_mem = std::make_shared<MKLDNNMemory>(getEngine());
340     w_bias_mem->Create(w_bias_d);
341     internalBlobMemory.push_back(w_bias_mem);
342
343     {
344         /* Copy Weight data
345          * IE format:
346          *   W - [gates, out_state_size, in_data_size + in_state_size]
347          *   B - [gates, out_state_size]
348          *
349          * MKLDNN format:
350          *   W - [1, 1, in_date_size,  gates, out_state_size]
351          *   R - [1, 1, in_state_size, gates, out_state_size]
352          *   B - [gates, out_state_size]
353          *
354          *   Gate order
355          *   ====== LSTM ======
356          *   Caffe - IFOC, ONNX   - IOFC
357          *   IE    - FICO, mkldnn - IFCO
358          *
359          *   ====== GRU ======
360          *   IE - URO, mkldnn - URO
361          */
362         const int gate_map_lstm[] = {1, 0, 2, 3};  // FICO -> IFCO
363         const int gate_map_gru[]  = {0, 1, 2, 3};
364         const int gate_map_rnn[]  = {0};
365         const int *gate_map;
366         const int gate_map_lstm_size = sizeof(gate_map_lstm) / sizeof(int);
367         const int gate_map_gru_size = sizeof(gate_map_gru) / sizeof(int);
368         const int gate_map_rnn_size = sizeof(gate_map_rnn) / sizeof(int);
369         if (cell_desc.get_cell_kind() == vanilla_lstm) {
370             gate_map = gate_map_lstm;
371             if (G > gate_map_lstm_size) {
372                 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
373             }
374         } else if (cell_desc.get_cell_kind() == vanilla_gru) {
375             gate_map = gate_map_gru;
376             if (G > gate_map_gru_size) {
377                 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
378             }
379         } else if (cell_desc.get_cell_kind() == gru_linear_before_reset) {
380             gate_map = gate_map_gru;
381             if (G > gate_map_gru_size) {
382                 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
383             }
384         } else if (cell_desc.get_cell_kind() == vanilla_rnn) {
385             gate_map = gate_map_rnn;
386             if (G > gate_map_rnn_size) {
387                 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
388             }
389         } else {
390             gate_map = gate_map_gru;
391             if (G > gate_map_gru_size) {
392                 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
393             }
394         }
395
396         auto ie_w_ptr = getCnnLayer()->blobs["weights"]->buffer().as<const float*>();
397         auto w_ptr = static_cast<float*>(w_data_mem->GetData());
398         auto r_ptr = static_cast<float*>(w_state_mem->GetData());
399         const int step = SC * G;
400
401         for (int g = 0; g < G; g++) {
402             for (int out_i = 0; out_i < SC; out_i++) {
403                 float *l_w_ptr = w_ptr + gate_map[g]*SC + out_i;
404                 float *l_r_ptr = r_ptr + gate_map[g]*SC+ out_i;
405                 for (int in_i = 0; in_i < DC; in_i++) {
406                     *l_w_ptr = *ie_w_ptr;
407                     ie_w_ptr++;
408                     l_w_ptr += step;
409                 }
410
411                 for (int in_i = 0; in_i < SC; in_i++) {
412                     *l_r_ptr = *ie_w_ptr;
413                     ie_w_ptr++;
414                     l_r_ptr += step;
415                 }
416             }
417         }
418
419         if (w_bias_d) {
420             auto ie_b_ptr = getCnnLayer()->blobs["biases"]->buffer().as<const float*>();
421             auto b_ptr = static_cast<float*>(w_bias_mem->GetData());
422             for (int g = 0; g < Gb; g++) {
423                 float *l_b_ptr = b_ptr + gate_map[g]*SC;
424                 for (int out_i = 0; out_i < SC; out_i++) {
425                     *l_b_ptr = *ie_b_ptr;
426                     ie_b_ptr++;
427                     l_b_ptr++;
428                 }
429             }
430         }
431     }
432
433     auto src_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
434     src_state_mem->Create(in_state_d);
435     internalBlobMemory.push_back(src_state_mem);
436     if (in_state_d) {
437         int offset = 0;
438         for (int i = 0; i < S; i++) {
439             /* create copy/concat primitive */
440             auto src_stat = getParentEdgeAt(i+1)->getMemory().GetPrimitive();
441
442             auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
443             state_mem->Create(
444                     src_stat.get_primitive_desc().desc(),
445                     static_cast<uint8_t *>(src_state_mem->GetPrimitive().get_data_handle()) + offset);
446             offset += src_stat.get_primitive_desc().get_size();
447
448             internalBlobMemory.push_back(state_mem);
449
450             exec_before.emplace_back(src_stat, state_mem->GetPrimitive());
451         }
452     }
453
454     auto dst_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
455     dst_state_mem->Create(out_state_d);
456     internalBlobMemory.push_back(dst_state_mem);
457     if (out_state_d) {
458         int offset = 0;
459         int idx_start = is_cell ? 0 : 1;
460         for (int i = 0; i < S; i++) {
461             /* create copy/split primitive */
462             auto dst_stat = getChildEdgeAt(idx_start + i)->getMemory().GetPrimitive();
463
464             auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
465             state_mem->Create(
466                     dst_stat.get_primitive_desc().desc(),
467                     static_cast<uint8_t *>(dst_state_mem->GetPrimitive().get_data_handle()) + offset);
468             offset += dst_stat.get_primitive_desc().get_size();
469
470             internalBlobMemory.push_back(state_mem);
471
472             if (is_cell && i == 0) continue;
473             exec_after.emplace_back(state_mem->GetPrimitive(), dst_stat);
474         }
475     }
476
477     auto workspace_mem = std::make_shared<MKLDNNMemory>(getEngine());
478     workspace_mem->Create({}, memory::f32, memory::format_undef, nullptr);  // stub, not in use
479     internalBlobMemory.push_back(workspace_mem);
480
481     auto p = new rnn_forward(pd,
482             /* In Data       */ src_data_mem ->GetPrimitive(),
483             /* In State      */ src_state_mem->GetPrimitive(),
484             /* Weights data  */ w_data_mem   ->GetPrimitive(),
485             /* Weights state */ w_state_mem  ->GetPrimitive(),
486             /* Bias          */ w_bias_mem   ->GetPrimitive(),
487             /* Out Data      */ dst_data_mem ->GetPrimitive(),
488             /* Out State     */ dst_state_mem->GetPrimitive(),
489             /* Workspace     */ workspace_mem->GetPrimitive());
490
491     prim.reset(p);
492 }
493
494 void MKLDNNRNN::execute(mkldnn::stream strm) {
495     if (!exec_before.empty())
496         strm.submit({exec_before.begin(), exec_before.end()});
497
498     if (prim)
499         strm.submit({*prim});
500
501     if (!exec_after.empty())
502         strm.submit({exec_after.begin(), exec_after.end()});
503 }
504
505 }  // namespace MKLDNNPlugin