inference-engine/src/mkldnn_plugin/nodes/mkldnn_batchnorm_node.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

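// MKLDNNBatchNormalizationNode wraps the MKLDNN batch_normalization_forward
// primitive for inference and supports fusing a following ScaleShift layer
// into the same primitive.
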
#include "mkldnn_batchnorm_node.h"
#include "mkldnn_depthwise_node.h"
#include <mkldnn_extension_utils.h>
#include "ie_memcpy.h"

using namespace mkldnn;
using namespace MKLDNNPlugin;
using namespace InferenceEngine;

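// The constructor registers descriptor callbacks for the node's internal blobs:
// index 0 - variance, index 1 - mean, index 2 - packed scale-shift weights
// (only filled in when a ScaleShift layer has been fused into this node).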
MKLDNNBatchNormalizationNode::MKLDNNBatchNormalizationNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng): MKLDNNNode(layer, eng) {
    internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
        return GetVarianceDesc(primitive_desc_it.fetch());
    });
    internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
        return GetMeanDesc(primitive_desc_it.fetch());
    });

    internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
        if (!fusedWithScale())
            return MKLDNNMemoryDesc();
        return GetScaleShiftWeightsDesc(primitive_desc_it.fetch());
    });
}

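// Validates the BatchNormalization layer (single input, at least one output,
// non-empty weights/biases), copies the weights and biases into internal blobs
// together with the packed {2, C} data of a fused ScaleShift layer if present,
// and creates an FP32 descriptor for every layout available for the input dims.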
void MKLDNNBatchNormalizationNode::getSupportedDescriptors() {
    if (!descs.empty())
        return;
    auto * bnLayer = dynamic_cast<BatchNormalizationLayer*>(getCnnLayer().get());
    if (bnLayer == nullptr)
        THROW_IE_EXCEPTION << "Cannot convert batch normalization layer.";
    if (bnLayer->_weights == nullptr || bnLayer->_biases == nullptr) {
        THROW_IE_EXCEPTION << "Weights/biases are empty for layer: " << bnLayer->name
                           << " used in MKLDNN node: " << getName() << "\n"
                           << "Use ReadWeights and SetWeights methods of InferenceEngine::CNNNetReader"
                           << " to load them from .bin part of the IR";
    }

    if (getParentEdges().size() != 1)
        THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
    if (!getChildEdges().size())
        THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();

    eps = bnLayer->epsilon;

    size_t variancesSize = MKLDNNDims(bnLayer->_weights->dims()).size();
    size_t meansSize = MKLDNNDims(bnLayer->_biases->dims()).size();

    if (variancesSize != meansSize && variancesSize != 1)
        THROW_IE_EXCEPTION << "Incorrect weights and biases sizes!";

    internalBlobs.push_back(createInternalBlob(bnLayer->_weights->dims(), true));
    internalBlobs.push_back(createInternalBlob(bnLayer->_biases->dims(), false));

    auto parentOutDims = getParentEdgeAt(0)->getDims();

    if (fusedWith.size() > 1)
        THROW_IE_EXCEPTION << "BatchNorm fusion is possible with only one layer!";

    for (const auto &node : fusedWith) {
        auto * scshLayer = dynamic_cast<ScaleShiftLayer*>(node->getCnnLayer().get());
        if (scshLayer == nullptr)
            THROW_IE_EXCEPTION << "Cannot cast to the ScaleShift layer to fuse with BatchNorm.";

        // Pack the fused ScaleShift weights and biases into a single {2, C} blob:
        // the first row holds the scales, the second row holds the shifts
        // (filled with zeros if the layer has no biases).
        size_t C = static_cast<size_t>(getChildEdgeAt(0)->getDims()[1]);
        SizeVector mkldnn_weights = {2, C};
        TensorDesc desc(scshLayer->_weights->precision(), mkldnn_weights, InferenceEngine::NC);
        InferenceEngine::TBlob<float>::Ptr internalBlob = InferenceEngine::make_shared_blob<float>(desc);
        internalBlob->allocate();
        float * data = internalBlob->buffer();
        if (data == nullptr)
            THROW_IE_EXCEPTION << "Cannot get memory!";

        InferenceEngine::Blob::Ptr blb = scshLayer->_weights;
        if (blb == nullptr)
            THROW_IE_EXCEPTION << "Cannot get weights blob for node " << getName() << ".";

        size_t weightsByteSize = blb->byteSize();
        ie_memcpy(data, internalBlob->byteSize(), blb->buffer(), weightsByteSize);
        data += blb->size();
        blb = scshLayer->_biases;

        if (blb == nullptr) {
            memset(data, 0, weightsByteSize);
        } else {
            if (weightsByteSize != blb->byteSize())
                THROW_IE_EXCEPTION << "ScaleShift has incorrect weights!";
            ie_memcpy(data, internalBlob->byteSize(), blb->buffer(), weightsByteSize);
        }
        internalBlobs.push_back(internalBlob);
    }

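    // The MKLDNN batch normalization primitive is created for FP32 input data
    // regardless of the precision reported by the CNN layer.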
    InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
    if (precision != InferenceEngine::Precision::FP32)
        precision = InferenceEngine::Precision::FP32;
    auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);

    for (auto format : getAvailableFormatsForDims(parentOutDims)) {
        MKLDNNMemoryDesc in_candidate(parentOutDims, inputDataType, format);
        createDescriptor({in_candidate}, {});
    }
}

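// Returns the memory descriptor of the variance statistics for the given
// batch-normalization primitive descriptor. With use_global_stats the variance
// is an input (src index 2); otherwise it is produced as an output (dst index 2).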
MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetVarianceDesc(const memory::primitive_desc &primitive_desc) const {
    memory::primitive_desc aprimitive_desc;
    mkldnn_primitive_desc_t bndesc;
    mkldnn_batch_normalization_desc_t *p;
    error::wrap_c_api(mkldnn_primitive_desc_query(
            primitive_desc.get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
                      "could not get a batch-normalization descriptor");
    const_mkldnn_primitive_desc_t const_bndesc =
            (p->flags & use_global_stats) ?
            mkldnn_primitive_desc_query_pd(primitive_desc.get(),
                                           mkldnn::convert_to_c(src_pd), 2) :
            mkldnn_primitive_desc_query_pd(primitive_desc.get(),
                                           mkldnn::convert_to_c(dst_pd), 2);
    error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc),
                      "could not clone a variance primitive descriptor");
    aprimitive_desc.reset(bndesc);
    return MKLDNNMemoryDesc(aprimitive_desc.desc());
}

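// Same as GetVarianceDesc(), but for the mean statistics, which live at
// src/dst index 1 of the batch-normalization primitive descriptor.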
MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetMeanDesc(const memory::primitive_desc &primitive_desc) const {
    memory::primitive_desc aprimitive_desc;
    mkldnn_primitive_desc_t bndesc;
    mkldnn_batch_normalization_desc_t *p;
    error::wrap_c_api(mkldnn_primitive_desc_query(
            primitive_desc.get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
                      "could not get a batch-normalization descriptor");
    const_mkldnn_primitive_desc_t const_bndesc =
            (p->flags & use_global_stats) ?
            mkldnn_primitive_desc_query_pd(primitive_desc.get(),
                                           mkldnn::convert_to_c(src_pd), 1) :
            mkldnn_primitive_desc_query_pd(primitive_desc.get(),
                                           mkldnn::convert_to_c(dst_pd), 1);
    error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc),
                      "could not clone a mean primitive descriptor");
    aprimitive_desc.reset(bndesc);
    return MKLDNNMemoryDesc(aprimitive_desc.desc());
}

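// Returns the memory descriptor of the packed scale-shift weights (weights_pd
// at index 0) used when a ScaleShift layer is fused into this node.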
MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetScaleShiftWeightsDesc(const memory::primitive_desc &primitive_desc) const {
    memory::primitive_desc adesc;
    mkldnn_primitive_desc_t bndesc;
    const_mkldnn_primitive_desc_t const_bndesc =
            mkldnn_primitive_desc_query_pd(primitive_desc.get(),
                                           mkldnn::convert_to_c(weights_pd), 0);
    error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc, const_bndesc),
                      "could not clone a weights primitive descriptor");
    adesc.reset(bndesc);
    return MKLDNNMemoryDesc(adesc.desc());
}

bool MKLDNNBatchNormalizationNode::created() const {
    return getType() == BatchNormalization;
}

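// Creates the batch_normalization_forward primitive. It consumes src, mean
// (internalBlobMemory[1]) and variance (internalBlobMemory[0]); when a
// ScaleShift layer is fused, the packed scale-shift weights
// (internalBlobMemory[2]) are passed as well.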
void MKLDNNBatchNormalizationNode::createPrimitive() {
    if (prim)
        return;

    if (fusedWithScale()) {
        auto prim_desc = createPrimitiveDescriptor<batch_normalization_forward::primitive_desc,
                batch_normalization_forward::desc>();
        prim.reset(new batch_normalization_forward(prim_desc,
                                                   getParentEdgeAt(0)->getMemory().GetPrimitive(),
                                                   (const primitive::at) internalBlobMemory[1]->GetPrimitive(),
                                                   (const primitive::at) internalBlobMemory[0]->GetPrimitive(),
                                                   (const primitive::at) internalBlobMemory[2]->GetPrimitive(),
                                                   getChildEdgeAt(0)->getMemory().GetPrimitive()));
    } else {
        auto prim_desc = createPrimitiveDescriptor<batch_normalization_forward::primitive_desc,
                batch_normalization_forward::desc>();
        prim.reset(new batch_normalization_forward(prim_desc,
                                                   getParentEdgeAt(0)->getMemory().GetPrimitive(),
                                                   (const primitive::at) internalBlobMemory[1]->GetPrimitive(),
                                                   (const primitive::at) internalBlobMemory[0]->GetPrimitive(),
                                                   getChildEdgeAt(0)->getMemory().GetPrimitive()));
    }
}

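// Builds the forward-inference batch normalization descriptor. 2D inputs are
// promoted to 4D NCHW (H = W = 1) because the primitive expects a 4D layout;
// use_global_stats is always set for scoring and use_scaleshift is added when
// a ScaleShift layer has been fused.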
void MKLDNNBatchNormalizationNode::createDescriptor(const std::vector<InferenceEngine::TensorDesc> &inputDesc,
                                                    const std::vector<InferenceEngine::TensorDesc> &outputDesc) {
    MKLDNNMemoryDesc inDesc(inputDesc[0]);
    if (inDesc.getDims().ndims() == 2) {
        // Make it 4D
        MKLDNNDims dims = inDesc.getDims();
        dims.push_back(1);  // H
        dims.push_back(1);  // W
        auto format = memory::nchw;
        inDesc = MKLDNNMemoryDesc(dims, inDesc.getDataType(), format);
    }

    unsigned flag = mkldnn_use_global_stats;
    if (fusedWithScale())
        flag |= mkldnn_use_scaleshift;
    MKLDNNDescriptor desc(std::shared_ptr<batch_normalization_forward::desc>(
            new batch_normalization_forward::desc(prop_kind::forward_scoring, inDesc, eps, flag)));
    descs.push_back(desc);
}

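// The selected configuration must have exactly one input and one output with
// identical tensor descriptors; whichever side is already defined is propagated
// to the other before the descriptor is reinitialized.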
void MKLDNNBatchNormalizationNode::initOptimalPrimitiveDescriptor() {
    auto config = getSelectedPrimitiveDescriptor()->getConfig();
    if (isInitConfig(config))
        return;

    if (config.inConfs.size() != 1 || config.outConfs.size() != 1 ||
            (!isUninitTensorDesc(config.inConfs[0].desc) &&
             !isUninitTensorDesc(config.outConfs[0].desc) &&
             config.inConfs[0].desc != config.outConfs[0].desc))
        THROW_IE_EXCEPTION << "Layer " << getName() << " has incorrect selected config!";

    if (!isUninitTensorDesc(config.inConfs[0].desc)) {
        config.outConfs[0].desc = config.inConfs[0].desc;
    } else if (!isUninitTensorDesc(config.outConfs[0].desc)) {
        config.inConfs[0].desc = config.outConfs[0].desc;
    } else {
        config.outConfs[0].desc = config.inConfs[0].desc = getConfiguredInputDesc(config, 0);
    }

    initDescriptor(config);
}

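// Enumerates all primitive descriptors produced by MKLDNN for the created
// descriptors and registers one supported configuration per implementation.
// Dynamic batch is supported and the output may be computed in place.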
void MKLDNNBatchNormalizationNode::initSupportedPrimitiveDescriptors() {
    if (!supportedPrimitiveDescriptors.empty())
        return;

    // BN primitive doesn't support strides
    for (auto& desc : descs) {
        primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
        do {
            InferenceEngine::LayerConfig config;
            config.dynBatchSupport = true;
            for (size_t i = 0; i < desc.inputNumbers(); i++) {
                InferenceEngine::DataConfig dataConfig;
                dataConfig.inPlace = -1;
                dataConfig.constant = false;
                dataConfig.desc = getSrcMemDesc(itpd, i);
                config.inConfs.push_back(dataConfig);
            }

            for (size_t i = 0; i < desc.outputNumbers(); i++) {
                InferenceEngine::DataConfig dataConfig;
                dataConfig.inPlace = canBeInPlace() ? 0 : -1;
                dataConfig.constant = false;
                dataConfig.desc = getDstMemDesc(itpd, i);
                config.outConfs.push_back(dataConfig);
            }
            impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());

            supportedPrimitiveDescriptors.emplace_back(config, impl_type);
        } while (itpd.next());
    }
}

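// Returns the input memory descriptor for the given primitive descriptor. For
// 2D parents the 4D NCHW descriptor reported by MKLDNN is reshaped back to NC.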
MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it,
                                                             size_t idx) {
    TensorDesc desc = MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());

    if (getParentEdgeAt(0)->getDims().ndims() == 2 && desc.getLayout() == InferenceEngine::Layout::NCHW) {
        desc.reshape(getParentEdgeAt(idx)->getDims().ToSizeVector(), InferenceEngine::Layout::NC);
        return MKLDNNMemoryDesc(desc);
    }
    if (desc.getLayout() == InferenceEngine::Layout::ANY)
        return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
                                                            getParentEdgeAt(idx)->getDims().ToSizeVector(),
                                                            desc.getLayout()));
    else
        return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
                                                            getParentEdgeAt(idx)->getDims().ToSizeVector(),
                                                            desc.getBlockingDesc()));
}

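// Same as getSrcMemDesc(), but for the output memory descriptor of the child edge.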
MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it,
                                                             size_t idx) {
    TensorDesc desc = MKLDNNMemoryDesc(primitive_desc_it.dst_primitive_desc(idx).desc());

    if (getParentEdgeAt(0)->getDims().ndims() == 2 && desc.getLayout() == InferenceEngine::Layout::NCHW) {
        desc.reshape(getParentEdgeAt(idx)->getDims().ToSizeVector(), InferenceEngine::Layout::NC);
        return MKLDNNMemoryDesc(desc);
    }
    if (desc.getLayout() == InferenceEngine::Layout::ANY)
        return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
                                                            getChildEdgeAt(idx)->getDims().ToSizeVector(),
                                                            desc.getLayout()));
    else
        return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
                                                            getChildEdgeAt(idx)->getDims().ToSizeVector(),
                                                            desc.getBlockingDesc()));
}