1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_rnn.h"
6 #include "mkldnn_extension_utils.h"
7 #include "desc_iterator.hpp"
8 #include <ie_layers_prv.h>
13 using namespace mkldnn;
14 using namespace InferenceEngine;
16 namespace MKLDNNPlugin {
/// Base case of the variadic membership test: true iff `val` equals `item`.
template <typename T, typename P>
inline bool one_of(T val, P item) {
    return item == val;
}
/// Variadic recursion: true iff `val` equals `item` or any of `item_others...`.
/// Short-circuits on the first match.
template <typename T, typename P, typename... Args>
inline bool one_of(T val, P item, Args... item_others) {
    return val == item || one_of(val, item_others...);
/// Map an Inference Engine RNN direction enum onto the corresponding
/// MKL-DNN rnn_direction value (FWD -> left2right, BWD -> right2left,
/// BDR -> bidirectional concat).
rnn_direction ie2mkl(RNNLayer::Direction &direction) {
    return direction == RNNLayer::RNN_FWD ? unidirectional_left2right
         : direction == RNNLayer::RNN_BWD ? unidirectional_right2left
         : direction == RNNLayer::RNN_BDR ? bidirectional_concat
// One node class serves two layer kinds: a full-sequence RNN and a
// single-timestep "LSTMCell". The layer-type string selects the mode.
MKLDNNRNN::MKLDNNRNN(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
    is_cell = layer->type == "LSTMCell";
// Node-type check used by the plugin's node factory: the expected type
// depends on which mode this instance was constructed in.
bool MKLDNNRNN::created() const {
    return getType() == (is_cell ? LSTMCell : RNN);
40 void MKLDNNRNN::getSupportedDescriptors() {
// Validate the LSTMCell layer's ports and blobs, then build the MKL-DNN
// memory descriptors and the primitive descriptor for the single-step cell.
// Cell mode ports: inputs {data, hidden state, cell state},
// outputs {hidden state, cell state}.
void MKLDNNRNN::fillCellDesc() {
    if (!descs.empty()) return;  // descriptors already prepared — idempotent
    auto cellLayer = std::dynamic_pointer_cast<InferenceEngine::LSTMCell>(getCnnLayer());
        THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";

    auto &ins = cellLayer->insData;
    auto &outs = cellLayer->outData;

        THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
        THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();

    auto in_data_dims = getParentEdgeAt(0)->getDims();
    auto in_h_state_dims = getParentEdgeAt(1)->getDims();
    auto in_c_state_dims = getParentEdgeAt(2)->getDims();

    auto out_h_state_dims = getChildEdgeAt(0)->getDims();
    auto out_c_state_dims = getChildEdgeAt(1)->getDims();

    // Cell mode works on rank-2 tensors: [batch, channels].
    if (in_data_dims.ndims() != 2
        || in_h_state_dims.ndims() != 2
        || in_c_state_dims.ndims() != 2
        || out_h_state_dims.ndims() != 2
        || out_c_state_dims.ndims() != 2)
        THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();

    // SC = state channels, taken from the hidden-state input.
    SC = in_h_state_dims[1];

    // D_shape: data input {batch, data channels}; S_shape: any state port.
    MKLDNNDims D_shape {N, DC}, S_shape {N, SC};

    // All state ports (in and out) must share the same {N, SC} shape.
    if (in_data_dims != D_shape
        || in_h_state_dims != S_shape
        || in_c_state_dims != S_shape
        || out_h_state_dims != S_shape
        || out_c_state_dims != S_shape)
        THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();

    auto blobs = cellLayer->blobs;
    Blob::Ptr weights, bias;
    if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
    if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];

        THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";

    // Weights hold both the input-projection (DC) and recurrent (SC) parts
    // for all G gates.
    if (weights->size() != G*SC*(SC+DC))
        THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);

    if (bias && bias->size() != G*SC)
        THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;

    // Shapes and Attributes are correct. Can start internal stuff initialization.

    // States use MKL-DNN's ldsnc layout: {layers, directions, states, batch, channels}.
    in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};
    out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};

    // Data uses tnc: {time, batch, channels}; T is a single step in cell mode.
    in_data_d = {{T, N, DC}, memory::f32, memory::tnc};;  // NOTE(review): stray ';' is harmless
    out_data_d = {{T, N, SC}, memory::f32, memory::tnc};;

    // Weights use ldigo: {layers, directions, input channels, gates, output channels}.
    w_data_d = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
    w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};

    w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};

    // Plain 2D (nc) layouts are exposed to the graph for all ports.
    std::vector<TensorDesc> in_candidate;
    in_candidate.emplace_back(MKLDNNMemoryDesc {D_shape, memory::f32, memory::nc});
    in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
    in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});

    std::vector<TensorDesc> out_candidate;
    out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
    out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});

    createDescriptor(in_candidate, out_candidate);
// Validate the full-sequence RNN layer (LSTM cell type only) and build the
// MKL-DNN descriptors. Sequence mode ports: input {data[, h state, c state]},
// output {data[, h state, c state]} — states are optional (1 or 3 ports).
void MKLDNNRNN::fillSeqDesc() {
    if (!descs.empty()) return;  // descriptors already prepared — idempotent
    auto rnnLayer = std::dynamic_pointer_cast<RNNLayer>(getCnnLayer());

        THROW_IE_EXCEPTION << "Wrong RNN layer representation. Cannot cast to RNNLayer.";

    if (!one_of(rnnLayer->cellType, "LSTM"))
        THROW_IE_EXCEPTION << "RNN layer supports only LSTM like cell";

    // axis==0 means time-major (tnc), which is MKL-DNN's native order.
    if (!one_of(rnnLayer->axis, 0, 1))
        THROW_IE_EXCEPTION << "RNN layer supports only sequence axis 0 or 1";
    nativeOrder = rnnLayer->axis == 0;

    if (!one_of(rnnLayer->direction, RNNLayer::RNN_FWD, RNNLayer::RNN_BWD))
        THROW_IE_EXCEPTION << "RNN layer supports only unidirectional RNN layer";
    direction = ie2mkl(rnnLayer->direction);

    auto &ins = rnnLayer->insData;
    auto &outs = rnnLayer->outData;

    if (!one_of(ins.size(), 3, 1))
        THROW_IE_EXCEPTION << "Incorrect number of input ports for layer " << getName();
    if (!one_of(outs.size(), 3, 1))
        THROW_IE_EXCEPTION << "Incorrect number of output ports for layer " << getName();

    auto in_data_dims = getParentEdgeAt(0)->getDims();
    auto out_data_dims = getChildEdgeAt(0)->getDims();

    if (in_data_dims.ndims() != 3 || out_data_dims.ndims() != 3)
        THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();

    // Batch-major input: swap to time-major {T, N, C} for internal bookkeeping.
    std::swap(in_data_dims[0], in_data_dims[1]);
    std::swap(out_data_dims[0], out_data_dims[1]);

    DC = in_data_dims[2];   // data (input) channels
    SC = out_data_dims[2];  // state (output) channels

    MKLDNNDims ID_shape {T, N, DC}, OD_shape {T, N, SC}, S_shape {N, SC};

    if (out_data_dims != OD_shape)
        THROW_IE_EXCEPTION << "Incorrect shape of input/output ports for layer " << getName();

    // Optional initial-state inputs (ports 1 and 2).
    if (ins.size() == 3) {
        auto state_dims1 = getParentEdgeAt(1)->getDims();
        auto stats_dims2 = getParentEdgeAt(2)->getDims();

        if (state_dims1 != S_shape || stats_dims2 != S_shape)
            THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();

        in_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};

    // Optional final-state outputs (ports 1 and 2).
    if (outs.size() == 3) {
        auto state_dims1 = getChildEdgeAt(1)->getDims();
        auto stats_dims2 = getChildEdgeAt(2)->getDims();

        if (state_dims1 != S_shape || stats_dims2 != S_shape)
            THROW_IE_EXCEPTION << "Incorrect shape of state ports for layer " << getName();

        out_state_d = {{L, D, S, N, SC}, memory::f32, memory::ldsnc};

    auto blobs = rnnLayer->blobs;
    Blob::Ptr weights, bias;
    if (blobs.find("weights") != blobs.end()) weights = blobs["weights"];
    if (blobs.find("biases") != blobs.end()) bias = blobs["biases"];

        THROW_IE_EXCEPTION << "RNN Layer. Weights do not present.";

    if (weights->size() != G*SC*(SC+DC))
        THROW_IE_EXCEPTION << "RNN Layer. Weights size is not correct. Expected size:" << G*SC*(SC+DC);

    // ldigo: {layers, directions, input channels, gates, output channels}.
    w_data_d = {{L, D, DC, G, SC}, memory::f32, memory::ldigo};
    w_state_d = {{L, D, SC, G, SC}, memory::f32, memory::ldigo};

    if (bias && bias->size() != G*SC)
        THROW_IE_EXCEPTION << "RNN Layer. Biases size is not correct. Expected size:" << G*SC;

        w_bias_d = {{L, D, G, SC}, memory::f32, memory::ldgo};

    // Try to create descriptor and corresponding configuration
    in_data_d = {in_data_dims, memory::f32, memory::tnc};
    out_data_d = {out_data_dims, memory::f32, memory::tnc};

    std::vector<TensorDesc> in_candidate;
    // Native (time-major) layout, or ntc when the graph is batch-major.
        in_candidate.push_back(in_data_d);
        in_candidate.push_back(MKLDNNMemoryDesc{{N, T, DC}, memory::f32, memory::ntc});

    if (ins.size() == 3) {
        in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
        in_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});

    std::vector<TensorDesc> out_candidate;
        out_candidate.push_back(out_data_d);
        out_candidate.push_back(MKLDNNMemoryDesc{{N, T, SC}, memory::f32, memory::ntc});

    if (outs.size() == 3) {
        out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});
        out_candidate.emplace_back(MKLDNNMemoryDesc {S_shape, memory::f32, memory::nc});

    createDescriptor(in_candidate, out_candidate);
// Build the MKL-DNN vanilla-LSTM forward descriptor from the member memory
// descriptors, and register one supported primitive-descriptor configuration
// whose port descs mirror inputDesc/outputDesc.
void MKLDNNRNN::createDescriptor(const std::vector<TensorDesc> &inputDesc,
                                 const std::vector<TensorDesc> &outputDesc) {
    MKLDNNDescriptor desc(std::shared_ptr<rnn_forward::desc>(
            new rnn_forward::desc(forward_scoring,
                    {algorithm::vanilla_lstm, algorithm::eltwise_tanh },
                    /* In Data       */ in_data_d,
                    /* In State      */ in_state_d,
                    /* Weights data  */ w_data_d,
                    /* Weights state */ w_state_d,
                    /* Out Data      */ out_data_d,
                    /* Out State     */ out_state_d)));
    descs.push_back(desc);

    // Fill supported config
    InferenceEngine::LayerConfig config;
    config.dynBatchSupport = false;  // batch is baked into the descriptors
    for (size_t i = 0; i < inputDesc.size(); i++) {
        InferenceEngine::DataConfig dataConfig;
        dataConfig.inPlace = -1;        // no in-place reuse
        dataConfig.constant = false;
        dataConfig.desc = inputDesc[i];
        config.inConfs.push_back(dataConfig);

    for (size_t i = 0; i < outputDesc.size(); i++) {
        InferenceEngine::DataConfig dataConfig;
        dataConfig.inPlace = -1;
        dataConfig.constant = false;
        dataConfig.desc = outputDesc[i];
        config.outConfs.push_back(dataConfig);

    supportedPrimitiveDescriptors.push_back({config, ref_any});
// Instantiate the rnn_forward primitive: create internal weight/bias/state
// memories, repack IE-layout weights into MKL-DNN's ldigo layout with gate
// reordering, and set up pre/post copy primitives that concat the two input
// state ports into one ldsnc buffer and split the output buffer back.
void MKLDNNRNN::createPrimitive() {
    std::shared_ptr<rnn_forward::desc> d = descs[0];
    rnn_forward::primitive_desc pd(*d, getEngine());

    auto src_data_mem = getParentEdgeAt(0)->getMemoryPtr();
    auto dst_data_mem = getChildEdgeAt(0)->getMemoryPtr();

    // create weight blobs (data and state part)
    auto w_data_mem = std::make_shared<MKLDNNMemory>(getEngine());
    w_data_mem->Create(w_data_d);
    internalBlobMemory.push_back(w_data_mem);

    auto w_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    w_state_mem->Create(w_state_d);
    internalBlobMemory.push_back(w_state_mem);

    auto w_bias_mem = std::make_shared<MKLDNNMemory>(getEngine());
    w_bias_mem->Create(w_bias_d);
    internalBlobMemory.push_back(w_bias_mem);

     * W - [gates, out_state_size, in_data_size + in_state_size]
     * B - [gates, out_state_size]
     * W - [1, 1, in_data_size, gates, out_state_size]
     * R - [1, 1, in_state_size, gates, out_state_size]
     * B - [gates, out_state_size]
     * Caffe - IFOC, ONNX - IOFC
     * IE - FICO, mkldnn - IFCO

    // Permutation from IE gate order to MKL-DNN gate order.
    const int gate_map[] = {1, 0, 2, 3};

    auto ie_w_ptr = getCnnLayer()->blobs["weights"]->buffer().as<const float*>();
    auto w_ptr = static_cast<float*>(w_data_mem->GetData());
    auto r_ptr = static_cast<float*>(w_state_mem->GetData());
    const int step = SC * G;  // stride between consecutive input rows in ldigo

    // Transpose + gate-reorder: IE stores [G*SC, DC+SC] row-major; MKL-DNN
    // wants the input part (W) and recurrent part (R) separately in ldigo.
    for (int g = 0; g < G; g++) {
        for (int out_i = 0; out_i < SC; out_i++) {
            float *l_w_ptr = w_ptr + gate_map[g]*SC + out_i;
            float *l_r_ptr = r_ptr + gate_map[g]*SC+ out_i;
            for (int in_i = 0; in_i < DC; in_i++) {
                *l_w_ptr = *ie_w_ptr;
            for (int in_i = 0; in_i < SC; in_i++) {
                *l_r_ptr = *ie_w_ptr;

    // Bias: same gate reordering, contiguous per gate.
    auto ie_b_ptr = getCnnLayer()->blobs["biases"]->buffer().as<const float*>();
    auto b_ptr = static_cast<float*>(w_bias_mem->GetData());
    for (int g = 0; g < G; g++) {
        float *l_b_ptr = b_ptr + gate_map[g]*SC;
        for (int out_i = 0; out_i < SC; out_i++) {
            *l_b_ptr = *ie_b_ptr;

    // Combined input-state buffer (h and c halves live back to back).
    auto src_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    src_state_mem->Create(in_state_d);
    internalBlobMemory.push_back(src_state_mem);

    /* create copy/concat primitive */
    auto src_stat_1 = getParentEdgeAt(1)->getMemory().GetPrimitive();
    auto src_stat_2 = getParentEdgeAt(2)->getMemory().GetPrimitive();

    // Views over the first/second halves of the combined buffer.
    auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    low_half_state_mem->Create(
            src_stat_1.get_primitive_desc().desc(),
            src_state_mem->GetPrimitive().get_data_handle());
    internalBlobMemory.push_back(low_half_state_mem);

    auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    high_half_state_mem->Create(
            src_stat_2.get_primitive_desc().desc(),
            static_cast<uint8_t*>(src_state_mem->GetPrimitive().get_data_handle()) +
            src_stat_1.get_primitive_desc().get_size());
    internalBlobMemory.push_back(high_half_state_mem);

    // Copies run before the RNN primitive (see execute()).
    exec_before.emplace_back(src_stat_1, low_half_state_mem->GetPrimitive());
    exec_before.emplace_back(src_stat_2, high_half_state_mem->GetPrimitive());

    // Combined output-state buffer, split back into the child edges.
    auto dst_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    dst_state_mem->Create(out_state_d);
    internalBlobMemory.push_back(dst_state_mem);

    // In cell mode states are output ports 0/1; in sequence mode port 0 is
    // data, so states shift to ports 1/2.
    int idx_H = is_cell ? 0 : 1;
    int idx_C = is_cell ? 1 : 2;
    /* create copy/split primitive */
    auto dst_stat_1 = getChildEdgeAt(idx_H)->getMemory().GetPrimitive();
    auto dst_stat_2 = getChildEdgeAt(idx_C)->getMemory().GetPrimitive();

    auto low_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    low_half_state_mem->Create(
            dst_stat_1.get_primitive_desc().desc(),
            dst_state_mem->GetPrimitive().get_data_handle());
    internalBlobMemory.push_back(low_half_state_mem);

    auto high_half_state_mem = std::make_shared<MKLDNNMemory>(getEngine());
    high_half_state_mem->Create(
            dst_stat_2.get_primitive_desc().desc(),
            static_cast<uint8_t*>(dst_state_mem->GetPrimitive().get_data_handle()) +
            dst_stat_1.get_primitive_desc().get_size());
    internalBlobMemory.push_back(high_half_state_mem);

    // Copies run after the RNN primitive (see execute()).
    if (!is_cell) exec_after.emplace_back(low_half_state_mem->GetPrimitive(), dst_stat_1);
    exec_after.emplace_back(high_half_state_mem->GetPrimitive(), dst_stat_2);

    // Workspace is required by the API but unused in scoring mode.
    auto workspace_mem = std::make_shared<MKLDNNMemory>(getEngine());
    workspace_mem->Create({}, memory::f32, memory::format_undef, nullptr);  // stub, not in use
    internalBlobMemory.push_back(workspace_mem);

    auto p = new rnn_forward(pd,
            /* In Data       */ src_data_mem ->GetPrimitive(),
            /* In State      */ src_state_mem->GetPrimitive(),
            /* Weights data  */ w_data_mem   ->GetPrimitive(),
            /* Weights state */ w_state_mem  ->GetPrimitive(),
            /* Bias          */ w_bias_mem   ->GetPrimitive(),
            /* Out Data      */ dst_data_mem ->GetPrimitive(),
            /* Out State     */ dst_state_mem->GetPrimitive(),
            /* Workspace     */ workspace_mem->GetPrimitive());
// Run the node: pre-copies (state concat), the RNN primitive itself, then
// post-copies (state split) — all submitted to the given stream in order.
void MKLDNNRNN::execute(mkldnn::stream strm) {
    if (!exec_before.empty())
        strm.submit({exec_before.begin(), exec_before.end()});

        strm.submit({*prim});

    if (!exec_after.empty())
        strm.submit({exec_after.begin(), exec_after.end()});
443 } // namespace MKLDNNPlugin