Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / nodes / mkldnn_reorder_node.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "mkldnn_reorder_node.h"
6 #include <memory>
7 #include <string>
8 #include <algorithm>
9 #include <mkldnn_types.h>
10 #include <mkldnn_extension_utils.h>
11 #include "ie_parallel.hpp"
12
13 using namespace mkldnn;
14 using namespace MKLDNNPlugin;
15
16 MKLDNNReorderNode::MKLDNNReorderNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
17 }
18
19 void MKLDNNReorderNode::getSupportedDescriptors() {
20     if (outDims.empty() && output.getLayout() != InferenceEngine::Layout::ANY)
21         outDims.push_back(MKLDNNDims(output.getDims()));
22     if (inDims.empty() && input.getLayout() != InferenceEngine::Layout::ANY)
23         inDims.push_back(MKLDNNDims(input.getDims()));
24     if (getParentEdges().size() != 1)
25         THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
26     if (getChildEdges().empty())
27         THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
28 }
29
30 void MKLDNNReorderNode::initSupportedPrimitiveDescriptors() {
31     if (!supportedPrimitiveDescriptors.empty())
32         return;
33
34     auto inputDataType = MKLDNNMemoryDesc(input).getDataType();
35     auto outputDataType = MKLDNNMemoryDesc(output).getDataType();
36
37     auto parent = getParentEdgeAt(0)->getParent();
38     auto child = getChildEdgeAt(0)->getChild();
39
40     InferenceEngine::LayerConfig config;
41     config.dynBatchSupport = true;
42     config.inConfs.resize(1);
43     config.outConfs.resize(1);
44     config.inConfs[0].inPlace = -1;
45     config.inConfs[0].constant = false;
46     config.outConfs[0].inPlace = -1;
47     config.outConfs[0].constant = false;
48     if (input.getLayout() != InferenceEngine::Layout::ANY && output.getLayout() != InferenceEngine::Layout::ANY) {
49         config.inConfs[0].desc = input;
50         config.outConfs[0].desc = output;
51     } else if (parent->getSelectedPrimitiveDescriptor() != nullptr &&
52                child->getSelectedPrimitiveDescriptor() != nullptr) {
53         config.inConfs[0].desc = parent->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
54         config.outConfs[0].desc = child->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
55     } else {
56         config.inConfs[0].desc = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, memory::format::any);
57         config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::format::any);
58     }
59
60     supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::reorder);
61 }
62
63 void MKLDNNReorderNode::createPrimitive() {
64     auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
65     auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
66     if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
67         THROW_IE_EXCEPTION << "Destination memory didn't allocate.";
68     if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
69         THROW_IE_EXCEPTION << "Input memory didn't allocate.";
70     if (getSelectedPrimitiveDescriptor() == nullptr)
71         THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
72
73     createReorderPrimitive(srcMemPtr->GetDescriptor(), srcMemPtr->GetPrimitive().get_data_handle(),
74             dstMemPtr->GetDescriptor(), dstMemPtr->GetPrimitive().get_data_handle());
75 }
76
77 void MKLDNNReorderNode::createReorderPrimitive(mkldnn::memory::desc srcDesc, void* srcPtr, mkldnn::memory::desc dstDesc, void* dstPtr) {
78     src_blocked = std::make_shared<MKLDNNMemory>(getEngine());
79     src_blocked->Create(srcDesc, srcPtr);
80
81     dst_blocked = std::make_shared<MKLDNNMemory>(getEngine());
82     dst_blocked->Create(dstDesc, dstPtr);
83
84     mkldnn::primitive_attr attr;
85
86     if (_scales) {
87         std::vector<float> scales;
88
89         float* scaleData = static_cast<float*>(_scales->buffer());
90
91         for (size_t i = 0; i < _scales->size(); i++) {
92             scales.push_back(scaleData[i]);
93         }
94
95         int mask = 0;
96         int oc_dim_id = 1;
97         mask = 1 << oc_dim_id;
98
99         attr.set_output_scales(mask, scales);
100         attr.set_int_output_round_mode(round_nearest);
101     }
102
103     try {
104         // No autoblocking. Reorder can be applied as is
105         reorder::primitive_desc pd = reorder::primitive_desc(src_blocked->GetPrimitiveDescriptor(), dst_blocked->GetPrimitiveDescriptor(), attr);
106
107         prim.reset(new mkldnn::reorder(pd, src_blocked->GetPrimitive(), dst_blocked->GetPrimitive()));
108     } catch (...) {}
109 }
110
111 const std::vector<impl_desc_type>& MKLDNNReorderNode::getPrimitivesPriority() {
112     implPriorities = {impl_desc_type::reorder};
113     return implPriorities;
114 }
115
116 bool MKLDNNReorderNode::created() const {
117     return getType() == Reorder;
118 }
119
120 void MKLDNNReorderNode::execute(mkldnn::stream strm) {
121     src_blocked->GetPrimitivePtr()->set_data_handle(getParentEdgeAt(0)->getMemory().GetPrimitive().get_data_handle());
122     dst_blocked->GetPrimitivePtr()->set_data_handle(getChildEdgeAt(0)->getMemory().GetPrimitive().get_data_handle());
123     MKLDNNNode::execute(strm);
124 }
125
126 void MKLDNNReorderNode::setDynamicBatchLim(int lim) {
127     dynBatchLim = lim;
128     if (prim) {
129         auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
130         auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
131         memory::desc src_d = srcMemPtr->GetDescriptor();
132         memory::desc dst_d = dstMemPtr->GetDescriptor();
133         void *src_data_hdl = srcMemPtr->GetPrimitive().get_data_handle();
134         void *dst_data_hdl = dstMemPtr->GetPrimitive().get_data_handle();
135
136         src_d.data.dims[0] = batchToProcess();
137         src_d.data.layout_desc.blocking.padding_dims[0] = batchToProcess();
138
139         dst_d.data.dims[0] = batchToProcess();
140         dst_d.data.layout_desc.blocking.padding_dims[0] = batchToProcess();
141
142         createReorderPrimitive(src_d, src_data_hdl, dst_d, dst_data_hdl);
143     }
144 }