Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_permute_node.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_permute_node.h"
6 #include <ie_layers.h>
7 #include <string>
8 #include <mkldnn_types.h>
9 #include <mkldnn_extension_utils.h>
10 #include "ie_parallel.hpp"
11
12 using namespace mkldnn;
13 using namespace MKLDNNPlugin;
14 using namespace InferenceEngine;
15
16 MKLDNNPermuteNode::MKLDNNPermuteNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
17
18 void MKLDNNPermuteNode::getSupportedDescriptors() {
19     if (getParentEdges().size() != 1)
20         THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
21     if (!getChildEdges().size())
22         THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
23
24     auto& layer = getCnnLayer();
25     if (!layer) {
26         THROW_IE_EXCEPTION << "Cannot get CNNLayer.";
27     }
28
29     order.clear();
30     std::vector<int> layerOrder = layer->GetParamAsInts("order");
31     for (auto ord : layerOrder)
32         order.push_back(static_cast<size_t>(ord));
33 }
34
35 void MKLDNNPermuteNode::initSupportedPrimitiveDescriptors() {
36     if (!supportedPrimitiveDescriptors.empty())
37         return;
38
39     InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
40     if (precision != InferenceEngine::Precision::FP32)
41         precision = InferenceEngine::Precision::FP32;
42     auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
43     precision = getCnnLayer()->outData[0]->getPrecision();
44     if (precision != InferenceEngine::Precision::FP32)
45         precision = InferenceEngine::Precision::FP32;
46     auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
47
48     InferenceEngine::LayerConfig config;
49     config.dynBatchSupport = true;
50     config.inConfs.resize(1);
51     config.outConfs.resize(1);
52     config.inConfs[0].inPlace = -1;
53     config.inConfs[0].constant = false;
54     config.outConfs[0].inPlace = -1;
55     config.outConfs[0].constant = false;
56     if (getParentEdgeAt(0)->getDims().ndims() == 4) {
57         config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nchw);
58         config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nchw);
59         supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
60
61         auto srcDims = getParentEdgeAt(0)->getDims();
62         if (srcDims[1] % 8 == 0) {
63             config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nChw8c);
64             supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
65         }
66
67         if (srcDims[1] % 16 == 0) {
68             config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nChw16c);
69             supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
70         }
71     } else if (getParentEdgeAt(0)->getDims().ndims() == 5) {
72         config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::ncdhw);
73         config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::ncdhw);
74         supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
75
76         auto srcDims = getParentEdgeAt(0)->getDims();
77         if (srcDims[1] % 8 == 0) {
78             config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nCdhw8c);
79             supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
80         }
81
82         if (srcDims[1] % 16 == 0) {
83             config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nCdhw16c);
84             supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
85         }
86     } else {
87         config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::any);
88         config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
89                                                    MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims()));
90         supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
91     }
92 }
93
94 void MKLDNNPermuteNode::createPrimitive() {
95     auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
96     auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
97     if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
98         THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
99     if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
100         THROW_IE_EXCEPTION << "Input memory didn't allocate.";
101     if (getSelectedPrimitiveDescriptor() == nullptr)
102         THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
103 }
104
105 static void permute_to_0231(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
106     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
107     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
108     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
109     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
110     // Supports only NCHW to NHWC
111     int block_size = 1;
112     if (!MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat())) {
113         block_size = srcMemPtr->GetDescriptor().data.layout_desc.blocking.block_dims[1];
114     }
115
116     const int C = srcMemPtr->GetDims()[1];
117     const int H = srcMemPtr->GetDims()[2];
118     const int W = srcMemPtr->GetDims()[3];
119
120     // NHWC
121     const int src_stride = H * W * block_size;
122
123     int src_off = 0;
124     int dst_off = 0;
125
126     for (int n = 0; n < MB; n++) {
127         for (int h = 0; h < H; h++) {
128             for (int w = 0; w < W; w++) {
129                 src_off = n * C * H * W + (h * W + w) * block_size;
130
131                 for (int c = 0; c < C; c += block_size) {
132                     for (int b = 0; b < block_size; b++) {
133                         dst_data[dst_off] = src_data[src_off + b];
134                         dst_off++;
135                     }
136
137                     src_off += src_stride;
138                 }
139             }
140         }
141     }
142 }
143
144 static void permute_to_0213(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
145     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
146     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
147     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
148     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
149     int block_size = 1;
150     if (!MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat())) {
151         block_size = srcMemPtr->GetDescriptor().data.layout_desc.blocking.block_dims[1];
152     }
153
154     const int C = srcMemPtr->GetDims()[1];
155     const int H = srcMemPtr->GetDims()[2];
156     const int W = srcMemPtr->GetDims()[3];
157
158     parallel_for3d(MB, C/block_size, H, [&](int n, int c, int h) {
159         for (int w = 0; w < W; w++) {
160             int src_off = n*C*H*W + (c*H*W + h*W + w)*block_size;
161             int dst_off = n*C*H*W + (h*C*W + w + c*W)*block_size;
162             for (int b = 0; b < block_size; b++) {
163                 dst_data[dst_off + b] = src_data[src_off + b];
164             }
165         }
166     });
167 }
168
169 template <size_t scale_H = 0, size_t scale_W = 0>
170 static void permute_to_014253(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
171     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
172     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
173     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
174     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
175
176     const int C  = srcMemPtr->GetDims()[1];
177     const int CH  = scale_H > 0 ? static_cast<int>(scale_H) : srcMemPtr->GetDims()[2];
178     const int CW  = scale_W > 0 ? static_cast<int>(scale_W) : srcMemPtr->GetDims()[3];
179     const int H  = srcMemPtr->GetDims()[4];
180     const int W  = srcMemPtr->GetDims()[5];
181
182     int src_off = 0;
183     int dst_off = 0;
184
185     for (int n = 0; n < MB; n++) {
186         for (int c = 0; c < C; c++) {
187             for (int h = 0; h < H; h++) {
188                 for (int ch = 0; ch < CH; ch++) {
189                     for (int w = 0; w < W; w++) {
190                         for (int cw = 0; cw < CW; cw++) {
191                             src_off = n * C * CH * CW * H * W +
192                                       c * CH * CW * H * W +
193                                       ch * CW * H * W +
194                                       cw * H * W +
195                                       h * W +
196                                       w;
197
198                             dst_data[dst_off] = src_data[src_off];
199                             dst_off++;
200                         }
201                     }
202                 }
203             }
204         }
205     }
206 }
207
208 static void permute_to_3012(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
209     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
210     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
211     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
212     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
213
214     const int C  = srcMemPtr->GetDims()[1];
215     const int H  = srcMemPtr->GetDims()[2];
216     const int W  = srcMemPtr->GetDims()[3];
217
218     int src_off = 0;
219     int dst_off = 0;
220
221     for (int w = 0; w < W; w++) {
222         for (int n = 0; n < MB; n++) {
223             for (int c = 0; c < C; c++) {
224                 for (int h = 0; h < H; h++) {
225                      src_off = n * C * H * W +
226                                c * H * W +
227                                h * W +
228                                w;
229
230                      dst_data[dst_off] = src_data[src_off];
231                      dst_off++;
232                 }
233             }
234         }
235     }
236 }
237
238 static void permute_to_021(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
239     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
240     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
241     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
242     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
243
244     const int C  = srcMemPtr->GetDims()[1];
245     const int S  = srcMemPtr->GetDims()[2];
246
247     parallel_for2d(MB, S, [&](int n, int s) {
248         int src_off = 0;
249         int dst_off = 0;
250
251         for (int c = 0; c < C; c++) {
252             src_off = n * C * S +
253                       c * S +
254                       s;
255             dst_off = n * S * C +
256                       s * C +
257                       c;
258
259             dst_data[dst_off] = src_data[src_off];
260         }
261     });
262 }
263
264 static void permute_to_034152(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
265     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
266     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
267     src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
268     dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
269
270     const int DIM1 = srcMemPtr->GetDims()[1];
271     const int DIM2 = srcMemPtr->GetDims()[2];
272     const int DIM3 = srcMemPtr->GetDims()[3];
273     const int DIM4 = srcMemPtr->GetDims()[4];
274     const int DIM5 = srcMemPtr->GetDims()[5];
275
276     int src_off = 0;
277     int dst_off = 0;
278
279     for (int n = 0; n < MB; n++) {
280         for (int dim3 = 0; dim3 < DIM3; dim3++) {
281             for (int dim4 = 0; dim4 < DIM4; dim4++) {
282                 for (int dim1 = 0; dim1 < DIM1; dim1++) {
283                     for (int dim5 = 0; dim5 < DIM5; dim5++) {
284                         for (int dim2 = 0; dim2 < DIM2; dim2++) {
285                             src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 +
286                                       dim1 * DIM2 * DIM3 * DIM4 * DIM5 +
287                                       dim2 * DIM3 * DIM4 * DIM5 +
288                                       dim3 * DIM4 * DIM5 +
289                                       dim4 * DIM5 +
290                                       dim5;
291
292                             dst_data[dst_off] = src_data[src_off];
293                             dst_off++;
294                         }
295                     }
296                 }
297             }
298         }
299     }
300 }
301
302 std::multimap<InferenceEngine::SizeVector, MKLDNNPermuteNode::PermuteImpl> MKLDNNPermuteNode::OptimizedCases = {
303         {{0, 2, 3, 1}, MKLDNNPermuteNode::PermuteImpl(permute_to_0231, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
304             return true;
305         })},  // NCHW -> NHWC case
306         {{0, 1, 4, 2, 5, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_014253<2, 2>, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
307             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat()) && srcMemPtr->GetDims()[2] == 2 && srcMemPtr->GetDims()[3] == 2;
308         })},  // Dense upsample convolution case (scale = 2)
309         {{0, 1, 4, 2, 5, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_014253<0, 0>, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
310             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
311         })},  // Dense upsample convolution case (generic)
312         {{3, 0, 1, 2}, MKLDNNPermuteNode::PermuteImpl(permute_to_3012, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
313             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
314         })},  // LPR case
315         {{0, 2, 1, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_0213, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
316             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
317         })},  // shufflenet
318         {{0, 2, 1}, MKLDNNPermuteNode::PermuteImpl(permute_to_021, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
319             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
320         })},  // self attention block
321         {{0, 3, 4, 1, 5, 2}, MKLDNNPermuteNode::PermuteImpl(permute_to_034152, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
322             return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
323         })},  // learning-to-see-in-the-dark-sony
324 };
325
326 void MKLDNNPermuteNode::execute(mkldnn::stream strm) {
327     auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
328     auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
329     auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
330     auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
331
332     for (const auto &impl : OptimizedCases) {
333         if (impl.first == order && impl.second.isValidParams(srcMemPtr, dstMemPtr)) {
334             impl.second.execute(batchToProcess(), srcMemPtr, dstMemPtr);
335             return;
336         }
337     }
338
339     auto srcBlob = getParentEdgeAt(0)->getBlob();
340     TensorDesc srcDesc = srcBlob->getTensorDesc();
341
342     SizeVector& dims = srcDesc.getDims();
343     InferenceEngine::SizeVector orderedDims;
344     for (auto ord : order) {
345         orderedDims.push_back(dims[ord]);
346     }
347     TensorDesc dstDesc(InferenceEngine::Precision::FP32, dims, {orderedDims, order});
348
349     int dataSize = srcBlob->size() / srcDesc.getDims()[0] * batchToProcess();
350
351     parallel_for(dataSize, [&](int i) {
352         dst_data[dstDesc.offset(i)] = src_data[srcDesc.offset(i)];
353     });
354 }
355
356 bool MKLDNNPermuteNode::created() const {
357     return getType() == Permute;
358 }