1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_rnn.h"
6 #include "mkldnn_extension_utils.h"
7 #include "desc_iterator.hpp"
12 using namespace mkldnn;
13 using namespace InferenceEngine;
15 namespace MKLDNNPlugin {
// Variadic membership test: one_of(val, a, b, c) == (val == a || val == b || val == c).
// Comparison short-circuits on the first match; the recursion unrolls at compile time.
// (This chunk had lost the closing brace of the variadic overload — restored here.)
template <typename T, typename P>
inline bool one_of(T val, P item) { return val == item; }

template <typename T, typename P, typename... Args>
inline bool one_of(T val, P item, Args... item_others) {
    return val == item || one_of(val, item_others...);
}
24 using _RNN = RNNSequenceLayer; // alias
26 static rnn_direction ie2mkl(_RNN::Direction &direction) {
27 return direction == _RNN::FWD ? unidirectional_left2right
28 : direction == _RNN::BWD ? unidirectional_right2left
29 : direction == _RNN::BDR ? bidirectional_concat
33 static algorithm ie2mkl(std::string act_type) {
34 return act_type == "sigmoid" ? eltwise_logistic
35 : act_type == "tanh" ? eltwise_tanh
36 : act_type == "relu" ? eltwise_relu
40 static algorithm ie2mkl(RNNCellBase::CellType cell_type) {
42 case RNNCellBase::LSTM: return vanilla_lstm;
43 case RNNCellBase::GRU: return vanilla_gru;
44 case RNNCellBase::GRU_LBR: return gru_linear_before_reset;
45 case RNNCellBase::RNN: return vanilla_rnn;
47 THROW_IE_EXCEPTION << "Unsoupported cell type";
48 return algorithm_undef;
52 MKLDNNRNN::MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
53 is_cell = one_of(layer->type, "LSTMCell", "GRUCell", "RNNCell");
56 bool MKLDNNRNN::created() const {
57 return getType() == (is_cell ? RNNCell : RNNSeq);
// Entry point for descriptor creation. Presumably dispatches to fillCellDesc() in
// cell mode and fillSeqDesc() in sequence mode — the body (original lines 61-64)
// is not visible in this chunk; confirm against the full file.
60 void MKLDNNRNN::getSupportedDescriptors() {
// Build MKLDNN descriptors for single-time-step cell mode (LSTMCell/GRUCell/RNNCell):
// validates port counts and 2D shapes, checks weight/bias blob sizes, then fills the
// data/state/weight/bias memory descriptors and registers one descriptor config.
// NOTE(review): several original lines are missing from this chunk (the null-check
// guarding the first THROW, the T/N/DC... batch assignments around lines 98-99, and
// the `S == 2` guard around the cell-state checks) — verify against the full file.
67 void MKLDNNRNN::fillCellDesc() {
// Descriptors are built once; later calls are no-ops.
68 if (!descs.empty()) return;
69 auto cellLayer = std::dynamic_pointer_cast<RNNCellBase>(getCnnLayer());
// Thrown when the layer is not representable as RNNCellBase (the guard `if` is in
// the missing lines 70-71).
72 THROW_IE_EXCEPTION << "No original layer for RNNCell.";
74 algorithm cell_type = ie2mkl(cellLayer->cellType);
75 algorithm cell_act = ie2mkl(cellLayer->activations[0]); // Works only for RNN with one gate
77 cell_desc = {cell_type, cell_act};
// Optional clipping attribute; 0.0f means "no clipping".
78 if (cellLayer->clip != 0.0f)
79 cell_desc.set_clipping(cellLayer->clip);
81 auto &ins = cellLayer->insData;
82 auto &outs = cellLayer->outData;
// Cell mode: inputs = data [+ hidden state [+ cell state]], outputs = hidden [+ cell].
84 if (!one_of(ins.size(), 3, 2))
85 THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
86 if (!one_of(outs.size(), 2, 1))
87 THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
89 auto in_data_dims = getParentEdgeAt(0)->getDims();
90 auto in_h_state_dims = getParentEdgeAt(1)->getDims();
91 auto out_h_state_dims = getChildEdgeAt(0)->getDims();
// Cell mode works on rank-2 tensors: {batch, channels}.
93 if (in_data_dims.ndims() != 2 || in_h_state_dims.ndims() != 2)
94 THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
// G = gates per cell, S = number of state tensors (both derived from the cell kind).
96 G = cell_desc.get_gates_count();
97 S = cell_desc.get_state_count();
// DC = input-data channels, SC = state channels (hidden size).
100 DC = in_data_dims[1];
101 SC = in_h_state_dims[1];
// GRU with linear-before-reset carries one extra bias row.
103 Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
106 MKLDNNDims D_shape {N, DC}, S_shape {N, SC};
108 if (in_data_dims != D_shape
109 || in_h_state_dims != S_shape
110 || out_h_state_dims != S_shape)
111 THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
// Cell-state ports (LSTM only) — the `if (S == 2)` guard is in the missing lines.
114 auto in_c_state_dims = getParentEdgeAt(2)->getDims();
115 auto out_c_state_dims = getChildEdgeAt(1)->getDims();
117 if (in_c_state_dims != S_shape
118 || out_c_state_dims != S_shape)
119 THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
122 auto blobs = cellLayer->blobs;
123 Blob::Ptr weights, bias;
124 if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
125 if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
// Thrown when the weights blob is absent (the `if (!weights)` guard is in the
// missing lines 126-127).
128 THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
// IE packs W and R together: G*SC rows by (DC + SC) columns.
130 if (weights->size() != G*SC*(SC+DC))
131 THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
// NOTE(review): the check uses Gb*SC but the message prints G*SC — these differ for
// GRU_LBR, so the reported "expected size" can be wrong.
133 if (bias && bias->size() != Gb*SC)
134 THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
136 // Shapes and Attributes are correct. Can start internal stuff initialization.
// Packed state layout: {layers, directions, states, batch, channels}.
138 in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
139 out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
// (stray double semicolons below are harmless)
141 in_data_d = {{T, N, DC}, memory::f32, memory::tnc};;
142 out_data_d = {{T, N, SC}, memory::f32, memory::tnc};;
// Weights: {layers, directions, input-channels, gates, output-channels}.
144 w_data_d = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
145 w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
148 w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
// Register the plain-NC candidate configs; the extra state pair below is presumably
// guarded by `if (S == 2)` in the missing lines 154-155.
150 std::vector<TensorDesc> in_candidate, out_candidate;
151 in_candidate.emplace_back(MKLDNNMemoryDesc {D_shape, memory::f32, memory::nc});
152 in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
153 out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
156 in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
157 out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
160 createDescriptor(in_candidate, out_candidate);
// Build MKLDNN descriptors for sequence mode (LSTM/GRU/RNN over T time steps):
// validates cell type, axis, direction, port counts and 3D shapes, checks weight/bias
// blob sizes, then fills memory descriptors and registers one descriptor config.
// NOTE(review): this chunk is missing several original lines (the null-check guarding
// the cast THROW, the T/N assignments, closing braces of the state-port blocks, the
// `if (!weights)` guard, and the nativeOrder branches around the data candidates) —
// verify against the full file.
163 void MKLDNNRNN::fillSeqDesc() {
// Descriptors are built once; later calls are no-ops.
164 if (!descs.empty()) return;
165 auto rnnLayer = std::dynamic_pointer_cast<RNNSequenceLayer>(getCnnLayer());
// Thrown when the cast fails (the guard `if` is in the missing lines 166-167).
168 THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNSequenceLayer.";
170 if (!one_of(rnnLayer->cellType, _RNN::LSTM, _RNN::GRU, _RNN::GRU_LBR, _RNN::RNN))
171 THROW_IE_EXCEPTION << "RNN layer supports only LSTM/GRU/RNN cell";
173 algorithm cell_type = ie2mkl(rnnLayer->cellType);
// Sequence layers may omit activations (e.g. LSTM uses its fixed gate functions).
174 algorithm cell_act = algorithm_undef;
175 if (!rnnLayer->activations.empty())
176 cell_act = ie2mkl(rnnLayer->activations[0]); // Works only for RNN with one gate
178 cell_desc = {cell_type, cell_act};
// Optional clipping attribute; 0.0f means "no clipping".
180 if (rnnLayer->clip != 0.0f)
181 cell_desc.set_clipping(rnnLayer->clip);
// axis == 0 means data already comes as {T, N, C} (MKLDNN's native tnc order).
183 if (!one_of(rnnLayer->axis, 0, 1))
184 THROW_IE_EXCEPTION << "RNN layer supports only sequence axis 0 or 1";
185 nativeOrder = rnnLayer->axis == 0;
187 if (!one_of(rnnLayer->direction, _RNN::FWD, _RNN::BWD))
188 THROW_IE_EXCEPTION << "RNN layer supports only unidirectional RNN layer";
189 direction = ie2mkl(rnnLayer->direction);
191 auto &ins = rnnLayer->insData;
192 auto &outs = rnnLayer->outData;
// Sequence mode: data [+ initial hidden [+ initial cell]] in; data [+ states] out.
194 if (!one_of(ins.size(), 3, 2, 1))
195 THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
196 if (!one_of(outs.size(), 3, 2, 1))
197 THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();
199 auto in_data_dims = getParentEdgeAt(0)->getDims();
200 auto out_data_dims = getChildEdgeAt(0)->getDims();
201 if (in_data_dims.ndims() != 3 || out_data_dims.ndims() != 3)
203 THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
// Normalize batch-major {N, T, C} to time-major {T, N, C} for the checks below;
// presumably guarded by `if (!nativeOrder)` in the missing lines 204-205.
206 std::swap(in_data_dims[0], in_data_dims[1]);
207 std::swap(out_data_dims[0], out_data_dims[1]);
// G = gates per cell, S = number of state tensors (both derived from the cell kind).
210 G = cell_desc.get_gates_count();
211 S = cell_desc.get_state_count();
// DC = input-data channels, SC = state channels (hidden size).
214 DC = in_data_dims[2];
215 SC = out_data_dims[2];
// GRU with linear-before-reset carries one extra bias row.
217 Gb = (cell_type != gru_linear_before_reset) ? G : G + 1;
219 MKLDNNDims ID_shape {T, N, DC}, OD_shape {T, N, SC}, S_shape {N, SC};
221 if (out_data_dims != OD_shape)
222 THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();
// Optional initial-state inputs must each be {N, SC}.
224 if (ins.size() > 1) {
225 for (int i = 1; i < ins.size(); i++)
226 if (getParentEdgeAt(i)->getDims() != S_shape)
227 THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
229 in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
// Optional state outputs must each be {N, SC}.
232 if (outs.size() > 1) {
233 for (int i = 1; i < outs.size(); i++)
234 if (getChildEdgeAt(i)->getDims() != S_shape)
235 THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();
237 out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
240 auto blobs = rnnLayer->blobs;
241 Blob::Ptr weights, bias;
242 if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
243 if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];
// Thrown when the weights blob is absent (the `if (!weights)` guard is in the
// missing lines 244-245).
246 THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";
// IE packs W and R together: G*SC rows by (DC + SC) columns.
248 if (weights->size() != G*SC*(SC+DC))
249 THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);
// Weights: {layers, directions, input-channels, gates, output-channels}.
251 w_data_d = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
252 w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};
// NOTE(review): the check uses Gb*SC but the message prints G*SC — these differ for
// GRU_LBR, so the reported "expected size" can be wrong.
254 if (bias && bias->size() != Gb*SC)
255 THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;
258 w_bias_d = {{L, D, Gb, SC}, memory::f32, memory::ldgo};
260 // Try to create descriptor and corresponding configuration
261 in_data_d = {in_data_dims, memory::f32, memory::tnc};
262 out_data_d = {out_data_dims, memory::f32, memory::tnc};
// The tnc vs ntc candidates below are presumably selected by `if (nativeOrder)` /
// `else` in the missing lines 265/267 — only one applies per layer.
264 std::vector<TensorDesc> in_candidate;
266 in_candidate.push_back(in_data_d);
268 in_candidate.push_back(MKLDNNMemoryDesc{{N, T, DC}, memory::f32, memory::ntc});
270 for (int i = 1; i < ins.size(); i++)
271 in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
273 std::vector<TensorDesc> out_candidate;
275 out_candidate.push_back(out_data_d);
277 out_candidate.push_back(MKLDNNMemoryDesc{{N, T, SC}, memory::f32, memory::ntc});
279 for (int i = 1; i < outs.size(); i++)
280 out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
282 createDescriptor(in_candidate, out_candidate);
// Wrap the prepared memory descriptors into an MKLDNN rnn_forward descriptor and
// publish one supported primitive-descriptor config (reference implementation,
// no in-place, no dynamic batch).
// NOTE(review): the rnn_forward::desc argument list is missing lines in this chunk
// (original 289 and 294 — presumably the direction and bias descriptors), and the
// two for-loops below are missing their closing braces — verify against the full file.
285 void MKLDNNRNN::createDescriptor(const std::vector<TensorDesc> &inputDesc,
286 const std::vector<TensorDesc> &outputDesc) {
287 MKLDNNDescriptor desc(std::shared_ptr<rnn_forward::desc>(
288 new rnn_forward::desc(forward_scoring, cell_desc,
290 /* In Data */ in_data_d,
291 /* In State */ in_state_d,
292 /* Weights data */ w_data_d,
293 /* Weights state */ w_state_d,
295 /* Out Data */ out_data_d,
296 /* Out State */ out_state_d)));
297 descs.push_back(desc);
299 // Fill supported config
300 InferenceEngine::LayerConfig config;
301 config.dynBatchSupport = false;
// One DataConfig per input port, taken verbatim from the candidate descriptors.
302 for (size_t i = 0; i < inputDesc.size(); i++) {
303 InferenceEngine::DataConfig dataConfig;
304 dataConfig.inPlace = -1;
305 dataConfig.constant = false;
306 dataConfig.desc = inputDesc[i];
307 config.inConfs.push_back(dataConfig);
// One DataConfig per output port.
310 for (size_t i = 0; i < outputDesc.size(); i++) {
311 InferenceEngine::DataConfig dataConfig;
312 dataConfig.inPlace = -1;
313 dataConfig.constant = false;
314 dataConfig.desc = outputDesc[i];
315 config.outConfs.push_back(dataConfig);
// ref_any: reference implementation, any layout.
318 supportedPrimitiveDescriptors.push_back({config, ref_any});
// Instantiate the MKLDNN rnn_forward primitive: allocate internal weight/bias/state
// memories, repack IE's packed [W|R] weights into MKLDNN's ldigo layout with gate
// reordering, and wire pre/post state-copy reorders around the primitive.
// NOTE(review): this chunk is missing several original lines (the inner-loop pointer
// advances of the weight repack around lines 407-418, closing braces of the gate-map
// else-chain, and the final `prim.reset(p)` around lines 490-491) — verify against
// the full file before reasoning about the loop arithmetic.
321 void MKLDNNRNN::createPrimitive() {
324 std::shared_ptr<rnn_forward::desc> d = descs[0];
325 rnn_forward::primitive_desc pd(*d, getEngine());
327 auto src_data_mem = getParentEdgeAt(0)->getMemoryPtr();
328 auto dst_data_mem = getChildEdgeAt(0)->getMemoryPtr();
330 // create weight blobs (data and state part)
331 auto w_data_mem = std::make_shared<MKLDNNMemory>(getEngine());
332 w_data_mem->Create(w_data_d);
333 internalBlobMemory.push_back(w_data_mem);
335 auto w_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
336 w_state_mem->Create(w_state_d);
337 internalBlobMemory.push_back(w_state_mem);
339 auto w_bias_mem = std::make_shared<MKLDNNMemory>(getEngine());
340 w_bias_mem->Create(w_bias_d);
341 internalBlobMemory.push_back(w_bias_mem);
// Layout notes below are fragments of a larger /* ... */ comment whose delimiters
// fall in the missing lines. IE blob layout:
346 * W - [gates, out_state_size, in_data_size + in_state_size]
347 * B - [gates, out_state_size]
// MKLDNN (ldigo) layout:
350 * W - [1, 1, in_data_size, gates, out_state_size]
351 * R - [1, 1, in_state_size, gates, out_state_size]
352 * B - [gates, out_state_size]
// Gate orders per framework:
356 * Caffe - IFOC, ONNX - IOFC
357 * IE - FICO, mkldnn - IFCO
360 * IE - URO, mkldnn - URO
362 const int gate_map_lstm[] = {1, 0, 2, 3}; // FICO -> IFCO
363 const int gate_map_gru[] = {0, 1, 2, 3};
364 const int gate_map_rnn[] = {0};
366 const int gate_map_lstm_size = sizeof(gate_map_lstm) / sizeof(int);
367 const int gate_map_gru_size = sizeof(gate_map_gru) / sizeof(int);
368 const int gate_map_rnn_size = sizeof(gate_map_rnn) / sizeof(int);
// Select the IE->MKLDNN gate permutation for the actual cell kind; G must not
// index past the chosen map.
369 if (cell_desc.get_cell_kind() == vanilla_lstm) {
370 gate_map = gate_map_lstm;
371 if (G > gate_map_lstm_size) {
372 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
374 } else if (cell_desc.get_cell_kind() == vanilla_gru) {
375 gate_map = gate_map_gru;
376 if (G > gate_map_gru_size) {
377 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
379 } else if (cell_desc.get_cell_kind() == gru_linear_before_reset) {
380 gate_map = gate_map_gru;
381 if (G > gate_map_gru_size) {
382 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
384 } else if (cell_desc.get_cell_kind() == vanilla_rnn) {
385 gate_map = gate_map_rnn;
386 if (G > gate_map_rnn_size) {
387 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
// Fallback branch (original `} else {` is in the missing lines).
390 gate_map = gate_map_gru;
391 if (G > gate_map_gru_size) {
392 THROW_IE_EXCEPTION << "G isn't equal to the size of gate_map";
// Repack IE's packed [W|R] float blob into the separate W (data) and R (state)
// MKLDNN memories, permuting gates via gate_map.
396 auto ie_w_ptr = getCnnLayer()->blobs["weights"]->buffer().as<const float*>();
397 auto w_ptr = static_cast<float*>(w_data_mem->GetData());
398 auto r_ptr = static_cast<float*>(w_state_mem->GetData());
399 const int step = SC * G;
401 for (int g = 0; g < G; g++) {
402 for (int out_i = 0; out_i < SC; out_i++) {
403 float *l_w_ptr = w_ptr + gate_map[g]*SC + out_i;
404 float *l_r_ptr = r_ptr + gate_map[g]*SC+ out_i;
405 for (int in_i = 0; in_i < DC; in_i++) {
406 *l_w_ptr = *ie_w_ptr;
// (pointer advances by `step` and ie_w_ptr++ live in the missing lines 407-409)
411 for (int in_i = 0; in_i < SC; in_i++) {
412 *l_r_ptr = *ie_w_ptr;
// Repack biases with the same gate permutation; presumably guarded by
// `if (w_bias_d)` in the missing line 419 — confirm against the full file.
420 auto ie_b_ptr = getCnnLayer()->blobs["biases"]->buffer().as<const float*>();
421 auto b_ptr = static_cast<float*>(w_bias_mem->GetData());
422 for (int g = 0; g < Gb; g++) {
423 float *l_b_ptr = b_ptr + gate_map[g]*SC;
424 for (int out_i = 0; out_i < SC; out_i++) {
425 *l_b_ptr = *ie_b_ptr;
// Packed source-state buffer; each state input edge gets a reorder into its slice.
433 auto src_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
434 src_state_mem->Create(in_state_d);
435 internalBlobMemory.push_back(src_state_mem);
438 for (int i = 0; i < S; i++) {
439 /* create copy/concat primitive */
440 auto src_stat = getParentEdgeAt(i+1)->getMemory().GetPrimitive();
442 auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
// (the state_mem->Create(...) call spans the missing line 443)
444 src_stat.get_primitive_desc().desc(),
445 static_cast<uint8_t *>(src_state_mem->GetPrimitive().get_data_handle()) + offset);
446 offset += src_stat.get_primitive_desc().get_size();
448 internalBlobMemory.push_back(state_mem);
// Copy runs before the RNN primitive (see execute()).
450 exec_before.emplace_back(src_stat, state_mem->GetPrimitive());
// Packed destination-state buffer; each state output gets a reorder out of its slice.
454 auto dst_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
455 dst_state_mem->Create(out_state_d);
456 internalBlobMemory.push_back(dst_state_mem);
// In cell mode the hidden state is output port 0; in sequence mode port 0 is data.
459 int idx_start = is_cell ? 0 : 1;
460 for (int i = 0; i < S; i++) {
461 /* create copy/split primitive */
462 auto dst_stat = getChildEdgeAt(idx_start + i)->getMemory().GetPrimitive();
464 auto state_mem = std::make_shared<MKLDNNMemory>(getEngine());
// (the state_mem->Create(...) call spans the missing line 465)
466 dst_stat.get_primitive_desc().desc(),
467 static_cast<uint8_t *>(dst_state_mem->GetPrimitive().get_data_handle()) + offset);
468 offset += dst_stat.get_primitive_desc().get_size();
470 internalBlobMemory.push_back(state_mem);
// In cell mode the primitive writes the hidden state directly; skip the copy.
472 if (is_cell && i == 0) continue;
473 exec_after.emplace_back(state_mem->GetPrimitive(), dst_stat);
477 auto workspace_mem = std::make_shared<MKLDNNMemory>(getEngine());
478 workspace_mem->Create({}, memory::f32, memory::format_undef, nullptr); // stub, not in use
479 internalBlobMemory.push_back(workspace_mem);
// Assemble the forward primitive over all prepared memories; ownership transfer
// (prim.reset(p)) is in the missing lines 490-491.
481 auto p = new rnn_forward(pd,
482 /* In Data */ src_data_mem ->GetPrimitive(),
483 /* In State */ src_state_mem->GetPrimitive(),
484 /* Weights data */ w_data_mem ->GetPrimitive(),
485 /* Weights state */ w_state_mem ->GetPrimitive(),
486 /* Bias */ w_bias_mem ->GetPrimitive(),
487 /* Out Data */ dst_data_mem ->GetPrimitive(),
488 /* Out State */ dst_state_mem->GetPrimitive(),
489 /* Workspace */ workspace_mem->GetPrimitive());
// Run the node: first the reorders that pack state inputs into the contiguous state
// buffer, then the rnn_forward primitive, then the reorders that scatter the packed
// state outputs back to the output edges.
494 void MKLDNNRNN::execute(mkldnn::stream strm) {
495 if (!exec_before.empty())
496 strm.submit({exec_before.begin(), exec_before.end()});
// NOTE(review): the `if (prim)` guard (original lines 497-498) is not visible in
// this chunk — confirm prim is always set by createPrimitive().
499 strm.submit({*prim});
501 if (!exec_after.empty())
502 strm.submit({exec_after.begin(), exec_after.end()});
505 } // namespace MKLDNNPlugin