1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_permute_node.h"
8 #include <mkldnn_types.h>
9 #include <mkldnn_extension_utils.h>
10 #include "ie_parallel.hpp"
12 using namespace mkldnn;
13 using namespace MKLDNNPlugin;
14 using namespace InferenceEngine;
// Trivial constructor: forwards the IE layer and MKLDNN engine to the MKLDNNNode base.
// The permutation order itself is parsed later, in getSupportedDescriptors().
16 MKLDNNPermuteNode::MKLDNNPermuteNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
// Validates the node's edge counts and reads the "order" layer attribute
// (the axis permutation, e.g. {0, 2, 3, 1}) into the member `order`.
18 void MKLDNNPermuteNode::getSupportedDescriptors() {
// Permute consumes exactly one input tensor.
19 if (getParentEdges().size() != 1)
20 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
// At least one consumer must be attached.
21 if (!getChildEdges().size())
22 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
24 auto& layer = getCnnLayer();
// NOTE(review): the guarding `if` for this throw is not visible in this chunk
// (a line is elided here) — presumably `if (!layer)`; confirm against the full file.
26 THROW_IE_EXCEPTION << "Cannot get CNNLayer.";
// Parse the permutation as ints and widen each axis index to size_t.
30 std::vector<int> layerOrder = layer->GetParamAsInts("order");
31 for (auto ord : layerOrder)
32 order.push_back(static_cast<size_t>(ord));
// Registers the memory-format configurations this node can execute:
// plain nchw/ncdhw for 4D/5D inputs, plus channel-blocked layouts
// (nChw8c/nChw16c, nCdhw8c/nCdhw16c) when the channel count divides evenly,
// and a generic `any`-input/plain-output fallback.
35 void MKLDNNPermuteNode::initSupportedPrimitiveDescriptors() {
// Already initialized — nothing to do. NOTE(review): the `return;` that should
// follow this guard is elided in this chunk; confirm against the full file.
36 if (!supportedPrimitiveDescriptors.empty())
// Only FP32 is executed here: any other input/output precision is forced to FP32.
39 InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
40 if (precision != InferenceEngine::Precision::FP32)
41 precision = InferenceEngine::Precision::FP32;
42 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
43 precision = getCnnLayer()->outData[0]->getPrecision();
44 if (precision != InferenceEngine::Precision::FP32)
45 precision = InferenceEngine::Precision::FP32;
46 auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
// One input and one output config; no in-place execution, dynamic batch supported.
48 InferenceEngine::LayerConfig config;
49 config.dynBatchSupport = true;
50 config.inConfs.resize(1);
51 config.outConfs.resize(1);
52 config.inConfs[0].inPlace = -1;
53 config.inConfs[0].constant = false;
54 config.outConfs[0].inPlace = -1;
55 config.outConfs[0].constant = false;
// --- 4D case: plain nchw in/out first ---
56 if (getParentEdgeAt(0)->getDims().ndims() == 4) {
57 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nchw);
58 config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nchw);
59 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
61 auto srcDims = getParentEdgeAt(0)->getDims();
// Blocked input variants: only inConfs[0].desc is changed, the output stays plain.
62 if (srcDims[1] % 8 == 0) {
63 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nChw8c);
64 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
67 if (srcDims[1] % 16 == 0) {
68 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nChw16c);
69 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
// --- 5D case: mirrors the 4D logic with ncdhw / nCdhw8c / nCdhw16c ---
71 } else if (getParentEdgeAt(0)->getDims().ndims() == 5) {
72 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::ncdhw);
73 config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::ncdhw);
74 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
76 auto srcDims = getParentEdgeAt(0)->getDims();
77 if (srcDims[1] % 8 == 0) {
78 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nCdhw8c);
79 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
82 if (srcDims[1] % 16 == 0) {
83 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::nCdhw16c);
84 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
// --- Fallback (presumably an `else` branch; the keyword line is elided here):
// accept any input format, produce the plain format for the output rank. ---
87 config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::any);
88 config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
89 MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims()));
90 supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown});
// Sanity checks before execution: both memories must be allocated and a
// primitive descriptor must have been selected. No MKLDNN primitive is built
// in the visible part of this function — execute() performs the permute itself.
94 void MKLDNNPermuteNode::createPrimitive() {
95 auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
96 auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
97 if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
98 THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
99 if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
100 THROW_IE_EXCEPTION << "Input memory didn't allocate.";
101 if (getSelectedPrimitiveDescriptor() == nullptr)
102 THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
// Optimized kernel for order {0,2,3,1}: NCHW -> NHWC for `MB` batches.
// Also handles channel-blocked input (nChw8c/nChw16c) by reading `block_size`
// channels contiguously per spatial position.
105 static void permute_to_0231(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
106 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
107 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
108 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
109 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
110 // Supports only NCHW to NHWC
// NOTE(review): the declaration of `block_size` (presumably initialized to 1 for
// the plain-format case) is elided in this chunk; confirm against the full file.
112 if (!MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat())) {
113 block_size = srcMemPtr->GetDescriptor().data.layout_desc.blocking.block_dims[1];
116 const int C = srcMemPtr->GetDims()[1];
117 const int H = srcMemPtr->GetDims()[2];
118 const int W = srcMemPtr->GetDims()[3];
// Distance between consecutive channel blocks at a fixed (h, w) position.
121 const int src_stride = H * W * block_size;
126 for (int n = 0; n < MB; n++) {
127 for (int h = 0; h < H; h++) {
128 for (int w = 0; w < W; w++) {
129 src_off = n * C * H * W + (h * W + w) * block_size;
// NOTE(review): the computation/advance of `dst_off` is elided in this chunk.
131 for (int c = 0; c < C; c += block_size) {
132 for (int b = 0; b < block_size; b++) {
133 dst_data[dst_off] = src_data[src_off + b];
137 src_off += src_stride;
// Optimized kernel for order {0,2,1,3}: swaps the C and H axes of an NCHW
// tensor, with block-wise copies for channel-blocked input. Parallelized
// over (batch, channel-block, H) via ie_parallel.
144 static void permute_to_0213(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
145 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
146 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
147 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
148 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
// NOTE(review): the declaration of `block_size` (presumably 1 for the plain
// case) is elided in this chunk; confirm against the full file.
150 if (!MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat())) {
151 block_size = srcMemPtr->GetDescriptor().data.layout_desc.blocking.block_dims[1];
154 const int C = srcMemPtr->GetDims()[1];
155 const int H = srcMemPtr->GetDims()[2];
156 const int W = srcMemPtr->GetDims()[3];
// Each task copies one (n, c-block, h) row of W elements, block_size at a time.
158 parallel_for3d(MB, C/block_size, H, [&](int n, int c, int h) {
159 for (int w = 0; w < W; w++) {
160 int src_off = n*C*H*W + (c*H*W + h*W + w)*block_size;
161 int dst_off = n*C*H*W + (h*C*W + w + c*W)*block_size;
162 for (int b = 0; b < block_size; b++) {
163 dst_data[dst_off + b] = src_data[src_off + b];
// Optimized kernel for order {0,1,4,2,5,3} on a 6D tensor (dense upsample /
// pixel-shuffle-like pattern). The template parameters pin dims 2 and 3 at
// compile time (e.g. <2, 2> for scale 2); with <0, 0> they are read from the
// runtime dims instead.
169 template <size_t scale_H = 0, size_t scale_W = 0>
170 static void permute_to_014253(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
171 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
172 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
173 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
174 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
176 const int C = srcMemPtr->GetDims()[1];
// CH/CW come from the template when non-zero, otherwise from the input dims.
177 const int CH = scale_H > 0 ? static_cast<int>(scale_H) : srcMemPtr->GetDims()[2];
178 const int CW = scale_W > 0 ? static_cast<int>(scale_W) : srcMemPtr->GetDims()[3];
179 const int H = srcMemPtr->GetDims()[4];
180 const int W = srcMemPtr->GetDims()[5];
// Loop nest visits the destination in its natural order (n, c, h, ch, w, cw).
185 for (int n = 0; n < MB; n++) {
186 for (int c = 0; c < C; c++) {
187 for (int h = 0; h < H; h++) {
188 for (int ch = 0; ch < CH; ch++) {
189 for (int w = 0; w < W; w++) {
190 for (int cw = 0; cw < CW; cw++) {
191 src_off = n * C * CH * CW * H * W +
192 c * CH * CW * H * W +
// NOTE(review): the remaining src_off terms and the dst_off computation are
// elided in this chunk; confirm against the full file.
198 dst_data[dst_off] = src_data[src_off];
// Optimized kernel for order {3,0,1,2}: moves the W axis of an NCHW tensor to
// the front (W becomes the outermost dimension of the output).
208 static void permute_to_3012(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
209 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
210 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
211 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
212 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
214 const int C = srcMemPtr->GetDims()[1];
215 const int H = srcMemPtr->GetDims()[2];
216 const int W = srcMemPtr->GetDims()[3];
// Destination-order loop nest: w outermost, then n, c, h.
221 for (int w = 0; w < W; w++) {
222 for (int n = 0; n < MB; n++) {
223 for (int c = 0; c < C; c++) {
224 for (int h = 0; h < H; h++) {
225 src_off = n * C * H * W +
// NOTE(review): the remaining src_off terms and the dst_off computation are
// elided in this chunk; confirm against the full file.
230 dst_data[dst_off] = src_data[src_off];
// Optimized kernel for order {0,2,1} on a 3D tensor (N, C, S) -> (N, S, C);
// used for self-attention blocks. Parallelized over (batch, S position).
238 static void permute_to_021(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
239 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
240 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
241 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
242 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
244 const int C = srcMemPtr->GetDims()[1];
245 const int S = srcMemPtr->GetDims()[2];
247 parallel_for2d(MB, S, [&](int n, int s) {
// Each task transposes one column: all C values at position s.
251 for (int c = 0; c < C; c++) {
252 src_off = n * C * S +
// NOTE(review): the remaining offset terms are elided in this chunk;
// confirm against the full file.
255 dst_off = n * S * C +
259 dst_data[dst_off] = src_data[src_off];
// Optimized kernel for order {0,3,4,1,5,2} on a 6D tensor; added for the
// learning-to-see-in-the-dark-sony topology (see OptimizedCases below).
264 static void permute_to_034152(int MB, MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
265 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
266 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Skip past any padding offset recorded in the MKLDNN blocking descriptor.
267 src_data += srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
268 dst_data += dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding;
270 const int DIM1 = srcMemPtr->GetDims()[1];
271 const int DIM2 = srcMemPtr->GetDims()[2];
272 const int DIM3 = srcMemPtr->GetDims()[3];
273 const int DIM4 = srcMemPtr->GetDims()[4];
274 const int DIM5 = srcMemPtr->GetDims()[5];
// Loop nest visits the destination in its natural order (n, d3, d4, d1, d5, d2).
279 for (int n = 0; n < MB; n++) {
280 for (int dim3 = 0; dim3 < DIM3; dim3++) {
281 for (int dim4 = 0; dim4 < DIM4; dim4++) {
282 for (int dim1 = 0; dim1 < DIM1; dim1++) {
283 for (int dim5 = 0; dim5 < DIM5; dim5++) {
284 for (int dim2 = 0; dim2 < DIM2; dim2++) {
// src_off walks the source in its row-major (n, d1, d2, d3, d4, d5) layout.
285 src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 +
286 dim1 * DIM2 * DIM3 * DIM4 * DIM5 +
287 dim2 * DIM3 * DIM4 * DIM5 +
// NOTE(review): the remaining src_off terms and the dst_off computation are
// elided in this chunk; confirm against the full file.
292 dst_data[dst_off] = src_data[src_off];
// Dispatch table: maps a permutation order to an optimized kernel plus a
// predicate that checks whether the kernel may be used for the given memories
// (mostly: source must be in plain format). A multimap so one order can carry
// several candidates tried in sequence — e.g. the specialized <2, 2> upsample
// kernel is registered before the generic <0, 0> one for {0,1,4,2,5,3}.
302 std::multimap<InferenceEngine::SizeVector, MKLDNNPermuteNode::PermuteImpl> MKLDNNPermuteNode::OptimizedCases = {
// NOTE(review): this predicate's body is elided in this chunk; permute_to_0231
// itself handles both plain and blocked input, so the condition is presumably
// broader than IsPlainFormat — confirm against the full file.
303 {{0, 2, 3, 1}, MKLDNNPermuteNode::PermuteImpl(permute_to_0231, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
305 })}, // NCHW -> NHWC case
306 {{0, 1, 4, 2, 5, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_014253<2, 2>, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
307 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat()) && srcMemPtr->GetDims()[2] == 2 && srcMemPtr->GetDims()[3] == 2;
308 })}, // Dense upsample convolution case (scale = 2)
309 {{0, 1, 4, 2, 5, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_014253<0, 0>, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
310 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
311 })}, // Dense upsample convolution case (generic)
312 {{3, 0, 1, 2}, MKLDNNPermuteNode::PermuteImpl(permute_to_3012, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
313 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
315 {{0, 2, 1, 3}, MKLDNNPermuteNode::PermuteImpl(permute_to_0213, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
316 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
318 {{0, 2, 1}, MKLDNNPermuteNode::PermuteImpl(permute_to_021, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
319 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
320 })}, // self attention block
321 {{0, 3, 4, 1, 5, 2}, MKLDNNPermuteNode::PermuteImpl(permute_to_034152, [](MKLDNNMemoryPtr& srcMemPtr, MKLDNNMemoryPtr& dstMemPtr) {
322 return MKLDNNMemory::IsPlainFormat(srcMemPtr->GetFormat());
323 })}, // learning-to-see-in-the-dark-sony
// Runs the permute: first tries a matching optimized kernel from
// OptimizedCases; otherwise falls back to a generic element-by-element copy
// driven by TensorDesc::offset() on source and permuted destination descs.
326 void MKLDNNPermuteNode::execute(mkldnn::stream strm) {
327 auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
328 auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
329 auto src_data = reinterpret_cast<const float *>(srcMemPtr->GetData());
330 auto dst_data = reinterpret_cast<float *>(dstMemPtr->GetData());
// Fast path: matching order whose predicate accepts the current memories.
332 for (const auto &impl : OptimizedCases) {
333 if (impl.first == order && impl.second.isValidParams(srcMemPtr, dstMemPtr)) {
334 impl.second.execute(batchToProcess(), srcMemPtr, dstMemPtr);
// NOTE(review): the early `return;` after the optimized dispatch is elided in
// this chunk; confirm against the full file.
// Generic fallback: build a destination TensorDesc whose dims/blocking reflect
// the permutation, then copy element i via the two descs' offset() mappings.
339 auto srcBlob = getParentEdgeAt(0)->getBlob();
340 TensorDesc srcDesc = srcBlob->getTensorDesc();
342 SizeVector& dims = srcDesc.getDims();
343 InferenceEngine::SizeVector orderedDims;
344 for (auto ord : order) {
345 orderedDims.push_back(dims[ord]);
347 TensorDesc dstDesc(InferenceEngine::Precision::FP32, dims, {orderedDims, order});
// Scale element count by the (possibly reduced) batch being processed.
349 int dataSize = srcBlob->size() / srcDesc.getDims()[0] * batchToProcess();
351 parallel_for(dataSize, [&](int i) {
352 dst_data[dstDesc.offset(i)] = src_data[srcDesc.offset(i)];
// Identifies this node as a Permute node for the plugin's type dispatch.
356 bool MKLDNNPermuteNode::created() const {
357 return getType() == Permute;