platform/upstream/dldt.git: inference-engine/src/mkldnn_plugin/nodes/mkldnn_fullyconnected_node.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mkldnn_fullyconnected_node.h"
#include "mkldnn_activation_node.h"
#include "desc_iterator.hpp"
#include <ie_layers.h>
#include <algorithm>  // std::find, used in getPrimitivesPriority()
#include <string>
#include <vector>
#include <mkldnn_extension_utils.h>
#include <mkldnn.hpp>

using namespace mkldnn;
using namespace MKLDNNPlugin;
using namespace InferenceEngine;

MKLDNNFullyConnectedNode::MKLDNNFullyConnectedNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {
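    // The callbacks below tell the base class how to derive memory descriptors for the
    // internal blobs from the primitive descriptor that is finally selected: index 0
    // describes the weights, index 1 the optional biases (empty descriptor when absent).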
    internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
        return MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(0).desc());
    });
    internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
        if (internalBlobs.size() <= 1)
            return MKLDNNMemoryDesc();
        return MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(1).desc());
    });

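    // Quantization scales attached to the layer by graph quantization: "w-scale" holds the
    // per-channel weight scales, "oi-scale" (looked up below) holds the scales of the INT8
    // output expected by the next layer. Both are consumed later in initPrimitiveAttr().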
    auto ws = layer->blobs.find("w-scale");
    if (ws != layer->blobs.end()) {
        wScale = ws->second;
    }

    // Trying to find oi-scale
    if (getCnnLayer()->type == "FullyConnected" && getCnnLayer()->precision == Precision::I8) {
        auto ois = layer->blobs.find("oi-scale");
        if ((getCnnLayer()->outData[0]->getPrecision() == Precision::I8 || getCnnLayer()->outData[0]->getPrecision() == Precision::U8)
            && ois == layer->blobs.end()) {
            THROW_IE_EXCEPTION << "Internal error of graph quantization - mismatch of intermediate scales and next layer type for fully connected "
                << getCnnLayer()->name;
        }
        if (ois != layer->blobs.end()) {
            // If an oi-scale is present, the next layer consumes INT8 data.
            oScale = ois->second;
        }
    }
}

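// Collects the weight/bias internal blobs and creates an inner_product descriptor for
// every input memory format this node can accept; the actual implementation is chosen
// later from these descriptors.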
void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
    if (!descs.empty())
        return;

    InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
    auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
    precision = getCnnLayer()->outData[0]->getPrecision();
    auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);

    auto * fcLayer = dynamic_cast<FullyConnectedLayer*>(getCnnLayer().get());
    if (fcLayer == nullptr)
        THROW_IE_EXCEPTION << "Cannot convert fully connected layer.";
    if (fcLayer->_weights == nullptr) {
        THROW_IE_EXCEPTION << "Weights are empty for layer: " << fcLayer->name
                           << " used in MKLDNN node: " << getName() << "\n"
                           << "Use ReadWeights and SetWeights methods of InferenceEngine::CNNNetReader"
                           << " to load them from .bin part of the IR";
    }

    if (getParentEdges().size() != 1)
        THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
    if (getChildEdges().empty())
        THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();

    MKLDNNDims inDims(fcLayer->input()->getDims());

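    // Weight dims follow the input rank: NC input -> OI weights, NCHW -> OIHW,
    // NCDHW -> OIDHW, with O equal to the layer's output neuron count (_out_num).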
    if (inDims.ndims() == 2) {
        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims.size(1))};
    } else if (inDims.ndims() == 4) {
        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims[1]), static_cast<size_t>(inDims[2]),
                       static_cast<size_t>(inDims[3])};
    } else if (inDims.ndims() == 5) {
        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims[1]), static_cast<size_t>(inDims[2]),
                       static_cast<size_t>(inDims[3]), static_cast<size_t>(inDims[4])};
    } else {
        THROW_IE_EXCEPTION << "Unsupported source format for FC layer. Expected 5, 4 or 2, got: "
                           << inDims.ndims() << " dims.";
    }

    internalBlobs.push_back(createInternalBlob(weightsDims, true));

    bool withBiases = (fcLayer->_biases != nullptr && fcLayer->_biases->size() != 0);
    if (withBiases) {
        biasesDims.push_back(static_cast<int>(fcLayer->_out_num));
        internalBlobs.push_back(createInternalBlob(biasesDims, false));
    }

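    // For INT8 weights the IR blob carries flat dims, so re-wrap the original weight/bias
    // buffers with the internal blobs' shaped descriptors at I8 (weights) / I32 (biases);
    // no data is copied, only the tensor descriptor changes.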
    Blob::Ptr weights = this->getCnnLayer()->blobs.find("weights")->second;
    if (weights->precision() == Precision::I8) {
        // The weights blob has incorrect dims, so we have to fix it
        TensorDesc wdesc = internalBlobs[0]->getTensorDesc();
        wdesc.setPrecision(Precision::I8);
        InferenceEngine::TBlob<int8_t>::Ptr reshapedInt8Weights =
                InferenceEngine::TBlob<int8_t>::Ptr(
                        new InferenceEngine::TBlob<int8_t>(wdesc, static_cast<int8_t*>(weights->buffer()), weights->byteSize()));

        internalBlobs[0] = reshapedInt8Weights;
        if (withBiases) {
            Blob::Ptr biases = this->getCnnLayer()->blobs.find("biases")->second;
            TensorDesc bdesc = internalBlobs[1]->getTensorDesc();
            bdesc.setPrecision(Precision::I32);
            InferenceEngine::TBlob<int32_t>::Ptr reshapedInt32Biases =
                    InferenceEngine::TBlob<int32_t>::Ptr(
                            new InferenceEngine::TBlob<int32_t>(bdesc, static_cast<int32_t*>(biases->buffer()), biases->byteSize()));
            internalBlobs[1] = reshapedInt32Biases;
        }
    }

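    // One descriptor per supported input layout; the output format is left as memory::any
    // so that MKL-DNN is free to pick the most efficient destination layout.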
    for (auto format : getAvailableFormatsForDims(getParentEdgeAt(0)->getDims())) {
        MKLDNNMemoryDesc in_candidate(inDims, inputDataType, format);
        MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType, memory::any);

        createDescriptor({in_candidate}, {out_candidate});
    }
}

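// Builds the MKL-DNN inner_product primitive from the selected primitive descriptor,
// wiring source, weights, optional bias and destination memories.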
void MKLDNNFullyConnectedNode::createPrimitive() {
    if (prim)
        return;

    // initPrimitiveAttr() always returns a valid attribute object for this node (it is
    // created unconditionally), so the primitive descriptor is built with it directly.
    std::shared_ptr<mkldnn::primitive_attr> attr = initPrimitiveAttr();
    std::shared_ptr<inner_product_forward::primitive_desc> prim_desc;
    prim_desc = std::make_shared<inner_product_forward::primitive_desc>(
            createPrimitiveDescriptor<inner_product_forward::primitive_desc, inner_product_forward::desc>(*attr));

    if (internalBlobs.size() > 1) {
        prim.reset(new inner_product_forward(*prim_desc,
                                             getParentEdgeAt(0)->getMemory().GetPrimitive(),
                                             internalBlobMemory[0]->GetPrimitive(),
                                             internalBlobMemory[1]->GetPrimitive(),
                                             getChildEdgeAt(0)->getMemory().GetPrimitive()));
    } else {
        prim.reset(new inner_product_forward(*prim_desc,
                                             getParentEdgeAt(0)->getMemory().GetPrimitive(),
                                             internalBlobMemory[0]->GetPrimitive(),
                                             getChildEdgeAt(0)->getMemory().GetPrimitive()));
    }
}

bool MKLDNNFullyConnectedNode::created() const {
    return getType() == FullyConnected ||
            getType() == FullyConnected_Activation;
}

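// Maps the source data layout to the matching weight layout: the batch dimension (n)
// maps to the output-channel dimension (o) and the input channels keep their blocking,
// e.g. nChw8c -> oIhw8i.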
memory::format MKLDNNFullyConnectedNode::weightsFormatForSrcFormat(memory::format sourceFormat) {
    switch (sourceFormat) {
        case memory::format::x:
            return memory::format::x;
        case memory::format::nc:
            return memory::format::oi;
        case memory::format::nchw:
            return memory::format::oihw;
        case memory::format::ncdhw:
            return memory::format::oidhw;
        case memory::format::nChw8c:
            return memory::format::oIhw8i;
        case memory::format::nCdhw8c:
            return memory::format::oIdhw8i;
        case memory::format::nChw16c:
            return memory::format::oIhw16i;
        case memory::format::nCdhw16c:
            return memory::format::oIdhw16i;
        default:
            THROW_IE_EXCEPTION << "Unsupported source format for node " << getName();
    }
}

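// Preference order of implementations for FC: GEMM-based kernels come first, then the
// JIT kernels, with the reference implementation last. Priorities already present in
// implPriorities (e.g. requested externally) keep their earlier position.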
const std::vector<impl_desc_type>& MKLDNNFullyConnectedNode::getPrimitivesPriority() {
    std::vector<impl_desc_type> priorities = {
            impl_desc_type::unknown,
            impl_desc_type::gemm_blas,
            impl_desc_type::gemm_avx512,
            impl_desc_type::gemm_avx2,
            impl_desc_type::gemm_avx,
            impl_desc_type::gemm_sse42,
            impl_desc_type::gemm_any,
            impl_desc_type::gemm,
            impl_desc_type::jit_uni_dw,
            impl_desc_type::jit_uni_1x1,
            impl_desc_type::jit_uni,
            impl_desc_type::jit_avx512_dw,
            impl_desc_type::jit_avx512_1x1,
            impl_desc_type::jit_avx512,
            impl_desc_type::jit_avx2_dw,
            impl_desc_type::jit_avx2_1x1,
            impl_desc_type::jit_avx2,
            impl_desc_type::jit_avx_dw,
            impl_desc_type::jit_avx_1x1,
            impl_desc_type::jit_avx,
            impl_desc_type::jit_sse42_dw,
            impl_desc_type::jit_sse42_1x1,
            impl_desc_type::jit_sse42,
            impl_desc_type::ref,
    };
    for (const auto& impl : priorities) {
        if (std::find(implPriorities.begin(), implPriorities.end(), impl) == implPriorities.end())
            implPriorities.push_back(impl);
    }
    return implPriorities;
}

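// Prepares primitive attributes: when quantization scales are present, the per-channel
// output scales are set to w-scale, or to w-scale / oi-scale when the layer output stays
// quantized (not FP32); fused activation nodes are appended as eltwise post-ops.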
std::shared_ptr<mkldnn::primitive_attr> MKLDNNFullyConnectedNode::initPrimitiveAttr() const {
    auto attr = std::make_shared<mkldnn::primitive_attr>(mkldnn::primitive_attr());
    bool scaled = false;
    if (wScale != nullptr) {
        float* wScaleData = static_cast<float*>(wScale->buffer());

        std::vector<float> oScaleDataVector;
        if (getCnnLayer()->precision == Precision::I8 && getCnnLayer()->outData[0]->getPrecision() != Precision::FP32) {
            float *oScaleData = static_cast<float *>(oScale->buffer());

            for (size_t c = 0; c < wScale->size(); c++) {
                oScaleDataVector.push_back(wScaleData[c] / oScaleData[c]);
            }
        } else {
            for (size_t c = 0; c < wScale->size(); c++) {
                oScaleDataVector.push_back(wScaleData[c]);
            }
        }

        attr->set_int_output_round_mode(mkldnn::round_nearest);
        attr->set_output_scales(1 << 1 /*through C dim*/, oScaleDataVector);
    }
    mkldnn::post_ops ops;
    for (auto &node : fusedWith) {
        auto* activationNode = dynamic_cast<MKLDNNActivationNode *>(node.get());
        if (activationNode) {
            ops.append_eltwise(1.0, activationNode->getAlgorithm(), activationNode->getAlpha(),
                               activationNode->getBeta());
        }
    }
    attr->set_post_ops(ops);
    return attr;
}

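// Creates the inner_product descriptor for the given input/output tensor descriptors.
// For INT8 weights the weight type is forced to s8 and the bias type to s32, and the
// output descriptor is rebuilt with NC layout in the precision requested downstream.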
void MKLDNNFullyConnectedNode::createDescriptor(const std::vector<InferenceEngine::TensorDesc> &inputDesc,
                                                const std::vector<InferenceEngine::TensorDesc> &outputDesc) {
    TensorDesc inDesc = inputDesc[0], outDesc = outputDesc[0];
    mkldnn::memory::data_type wdt = MKLDNNExtensionUtils::IEPrecisionToDataType(inDesc.getPrecision());
    mkldnn::memory::data_type bdt = MKLDNNExtensionUtils::IEPrecisionToDataType(inDesc.getPrecision());

    Blob::Ptr weights = this->getCnnLayer()->blobs.find("weights")->second;

    if (weights->precision() == Precision::I8) {
        wdt = memory::s8;
        bdt = memory::s32;

        Precision outPrec;
        if (getCnnLayer()->outData[0]->getPrecision() == Precision::FP32) {
            outPrec = Precision::FP32;
        } else {
            // define the precision according to the normalizer
            // TODO(amalyshe) do we need a separate flow for the last layer in an int8 chain or not?
            outPrec = outDesc.getPrecision();
        }

        inDesc = TensorDesc(inDesc.getPrecision(), inputDesc[0].getDims(), inputDesc[0].getBlockingDesc());
        outDesc = TensorDesc(outPrec, outputDesc[0].getDims(), Layout::NC/*, outputDesc[0].getBlockingDesc()*/);
    }

    MKLDNNMemoryDesc in_candidate(inDesc);
    MKLDNNMemoryDesc out_candidate(outDesc);

    memory::format weights_fmt = weightsFormatForSrcFormat(in_candidate.getFormat());

    MKLDNNMemoryDesc wgh_candidate(MKLDNNDims(weightsDims), wdt, weights_fmt);

    if (internalBlobs.size() > 1) {
        MKLDNNMemoryDesc bias_candidate(MKLDNNDims(biasesDims), bdt, memory::any);
        MKLDNNDescriptor desc(std::shared_ptr<inner_product_forward::desc>(
                new inner_product_forward::desc(prop_kind::forward_scoring, in_candidate, wgh_candidate,
                                                bias_candidate, out_candidate)));
        descs.push_back(desc);
    } else {
        MKLDNNDescriptor desc(std::shared_ptr<inner_product_forward::desc>(
                new inner_product_forward::desc(prop_kind::forward_scoring, in_candidate, wgh_candidate,
                                                out_candidate)));
        descs.push_back(desc);
    }
}