1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "mkldnn_batchnorm_node.h"
6 #include "mkldnn_depthwise_node.h"
7 #include <mkldnn_extension_utils.h>
10 using namespace mkldnn;
11 using namespace MKLDNNPlugin;
12 using namespace InferenceEngine;
// Constructor: registers lazy descriptor providers for the node's internal
// blobs. Order matters — blob 0 is variance, blob 1 is mean, blob 2 is the
// fused ScaleShift weights (empty descriptor when no ScaleShift is fused).
// Each lambda queries the selected primitive descriptor at preparation time.
// NOTE(review): the closing "});" of each lambda and the constructor's
// closing brace are absent from this excerpt — confirm against the full file.
14 MKLDNNBatchNormalizationNode::MKLDNNBatchNormalizationNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng): MKLDNNNode(layer, eng) {
// Internal blob 0: variance memory descriptor.
15 internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
16 return GetVarianceDesc(primitive_desc_it.fetch());
// Internal blob 1: mean memory descriptor.
18 internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
19 return GetMeanDesc(primitive_desc_it.fetch());
// Internal blob 2: scale/shift weights — only meaningful when a ScaleShift
// layer has been fused into this node.
22 internalBlobDesc.emplace_back([&](primitive_desc_iterator &primitive_desc_it, size_t idx) -> MKLDNNMemoryDesc {
23 if (!fusedWithScale())
24 return MKLDNNMemoryDesc();
25 return GetScaleShiftWeightsDesc(primitive_desc_it.fetch());
// Validates the IR layer, packs mean/variance (and, if a ScaleShift is fused,
// its interleaved scale+shift weights) into internal blobs, then creates one
// BN descriptor per available memory format for the input dims.
// Throws IE exceptions on any structural problem with the layer or edges.
// NOTE(review): several lines are missing from this excerpt (closing braces,
// and the guard conditions that should precede the bare THROW statements
// below) — confirm against the full file before relying on control flow.
29 void MKLDNNBatchNormalizationNode::getSupportedDescriptors() {
// The CNN layer must actually be a BatchNormalization layer.
32 auto * bnLayer = dynamic_cast<BatchNormalizationLayer*>(getCnnLayer().get());
33 if (bnLayer == nullptr)
34 THROW_IE_EXCEPTION << "Cannot convert batch normalization layer.";
// BN requires both weights (variance) and biases (mean) to be present in the IR.
35 if (bnLayer->_weights == nullptr || bnLayer->_biases == nullptr) {
36 THROW_IE_EXCEPTION << "Weights/biases are empty for layer: " << bnLayer->name
37 << " used in MKLDNN node: " << getName() << "\n"
38 << "Use ReadWeights and SetWeights methods of InferenceEngine::CNNNetReader"
39 << " to load them from .bin part of the IR";
// Exactly one input edge; at least one output edge.
42 if (getParentEdges().size() != 1)
43 THROW_IE_EXCEPTION << "Incorrect number of input edges for layer " << getName();
44 if (!getChildEdges().size())
45 THROW_IE_EXCEPTION << "Incorrect number of output edges for layer " << getName();
47 eps = bnLayer->epsilon;
// Variance and mean blobs must have matching element counts.
49 size_t variancesSize = MKLDNNDims(bnLayer->_weights->dims()).size();
50 size_t meansSize = MKLDNNDims(bnLayer->_biases->dims()).size();
52 if (variancesSize != meansSize && variancesSize != 1)
53 THROW_IE_EXCEPTION << "Incorrect weights and biases sizes!";
// Blob 0 = variance (weights), blob 1 = mean (biases) — order matches the
// internalBlobDesc lambdas registered in the constructor.
55 internalBlobs.push_back(createInternalBlob(bnLayer->_weights->dims(), true));
56 internalBlobs.push_back(createInternalBlob(bnLayer->_biases->dims(), false));
58 auto parentOutDims = getParentEdgeAt(0)->getDims();
// At most a single fused layer (a ScaleShift) is supported.
60 if (fusedWith.size() > 1)
61 THROW_IE_EXCEPTION << "BatchNorm fusion is possible with only one layer!";
63 for (const auto &node : fusedWith) {
64 auto * scshLayer = dynamic_cast<ScaleShiftLayer*>(node->getCnnLayer().get());
65 if (scshLayer == nullptr)
66 THROW_IE_EXCEPTION << "Cannot cast to the ScaleShift layer to fuse with BatchNorm.";
// Pack scale and shift into one {2, C} blob: row 0 = scales, row 1 = shifts,
// where C is the channel count of the output.
68 size_t C = static_cast<size_t>(getChildEdgeAt(0)->getDims()[1]);
69 SizeVector mkldnn_weights = {2, C};
70 TensorDesc desc(scshLayer->_weights->precision(), mkldnn_weights, InferenceEngine::NC);
71 InferenceEngine::TBlob<float>::Ptr internalBlob = InferenceEngine::make_shared_blob<float>(desc);
72 internalBlob->allocate();
73 float * data = internalBlob->buffer();
// NOTE(review): the guard (presumably "if (data == nullptr)") before this
// THROW is missing from the excerpt.
75 THROW_IE_EXCEPTION << "Cannot get memory!";
// Copy the scales into the first row of the packed blob.
77 InferenceEngine::Blob::Ptr blb = scshLayer->_weights;
// NOTE(review): the null-check guard before this THROW is missing from the excerpt.
79 THROW_IE_EXCEPTION << "Cannot get weights blob for node " << getName() << ".";
81 size_t weightsByteSize = blb->byteSize();
82 ie_memcpy(data, internalBlob->byteSize(), blb->buffer(), weightsByteSize);
// Second row: shifts (biases). Zero-filled when biases are absent — the
// branch structure around the memset is missing from this excerpt.
84 blb = scshLayer->_biases;
87 memset(data, 0, weightsByteSize);
// Shift blob must be the same size as the scale blob.
89 if (weightsByteSize != blb->byteSize())
90 THROW_IE_EXCEPTION << "ScaleShift has incorrect weights!";
91 ie_memcpy(data, internalBlob->byteSize(), blb->buffer(), weightsByteSize);
93 internalBlobs.push_back(internalBlob);
// BN is executed in FP32 regardless of the IR input precision.
96 InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
97 if (precision != InferenceEngine::Precision::FP32)
98 precision = InferenceEngine::Precision::FP32;
99 auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
// One BN descriptor per candidate memory format for the input dims.
101 for (auto format : getAvailableFormatsForDims(parentOutDims)) {
102 MKLDNNMemoryDesc in_candidate(parentOutDims, inputDataType, format);
103 createDescriptor({in_candidate}, {});
// Extracts the variance memory descriptor from a BN primitive descriptor.
// When use_global_stats is set the variance is an input (src index 2);
// otherwise the primitive produces it as an output (dst index 2).
// NOTE(review): the second argument of mkldnn_primitive_desc_clone
// (presumably "const_bndesc),") is missing from this excerpt.
107 MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetVarianceDesc(const memory::primitive_desc &primitive_desc) const {
108 memory::primitive_desc aprimitive_desc;
109 mkldnn_primitive_desc_t bndesc;
110 mkldnn_batch_normalization_desc_t *p;
// Query the underlying batch-normalization op descriptor to inspect its flags.
111 error::wrap_c_api(mkldnn_primitive_desc_query(
112 primitive_desc.get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
113 "could not get a batch-normalization descriptor");
114 const_mkldnn_primitive_desc_t const_bndesc =
115 (p->flags & use_global_stats) ?
116 mkldnn_primitive_desc_query_pd(primitive_desc.get(),
117 mkldnn::convert_to_c(src_pd), 2) :
118 mkldnn_primitive_desc_query_pd(primitive_desc.get(),
119 mkldnn::convert_to_c(dst_pd), 2);
// Clone so the returned descriptor owns its own handle.
120 error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc,
122 "could not clone a variance primitive descriptor");
123 aprimitive_desc.reset(bndesc);
124 return MKLDNNMemoryDesc(aprimitive_desc.desc());
// Extracts the mean memory descriptor from a BN primitive descriptor.
// Mirrors GetVarianceDesc but queries index 1 (mean) instead of 2 (variance):
// src index 1 under use_global_stats, dst index 1 otherwise.
// NOTE(review): the second argument of mkldnn_primitive_desc_clone
// (presumably "const_bndesc),") is missing from this excerpt.
127 MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetMeanDesc(const memory::primitive_desc &primitive_desc) const {
128 memory::primitive_desc aprimitive_desc;
129 mkldnn_primitive_desc_t bndesc;
130 mkldnn_batch_normalization_desc_t *p;
// Query the underlying batch-normalization op descriptor to inspect its flags.
131 error::wrap_c_api(mkldnn_primitive_desc_query(
132 primitive_desc.get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
133 "could not get a batch-normalization descriptor");
134 const_mkldnn_primitive_desc_t const_bndesc =
135 (p->flags & use_global_stats) ?
136 mkldnn_primitive_desc_query_pd(primitive_desc.get(),
137 mkldnn::convert_to_c(src_pd), 1) :
138 mkldnn_primitive_desc_query_pd(primitive_desc.get(),
139 mkldnn::convert_to_c(dst_pd), 1);
// Clone so the returned descriptor owns its own handle.
140 error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc,
142 "could not clone a mean primitive descriptor");
143 aprimitive_desc.reset(bndesc);
144 return MKLDNNMemoryDesc(aprimitive_desc.desc());
// Extracts the scale/shift weights memory descriptor (weights_pd index 0)
// from a BN primitive descriptor. Used only when a ScaleShift is fused.
// NOTE(review): the clone's second argument line and the "adesc.reset(bndesc);"
// statement appear to be missing from this excerpt — without the reset,
// "adesc" would be returned unset; confirm against the full file.
147 MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::GetScaleShiftWeightsDesc(const memory::primitive_desc &primitive_desc) const {
148 memory::primitive_desc adesc;
149 mkldnn_primitive_desc_t bndesc;
150 const_mkldnn_primitive_desc_t const_bndesc =
151 mkldnn_primitive_desc_query_pd(primitive_desc.get(),
152 mkldnn::convert_to_c(weights_pd), 0);
153 error::wrap_c_api(mkldnn_primitive_desc_clone(&bndesc,
155 "could not clone a weights primitive descriptor");
157 return MKLDNNMemoryDesc(adesc.desc());
// Returns true when this node instance was created as a BatchNormalization node.
160 bool MKLDNNBatchNormalizationNode::created() const {
161 return getType() == BatchNormalization;
// Builds the forward BN primitive. Argument order follows the MKL-DNN BN
// forward signature: src, mean, variance, [scale/shift weights], dst.
// internalBlobMemory indices match the constructor's registration order:
// [0] = variance, [1] = mean, [2] = fused scale/shift weights.
// NOTE(review): the "} else {" separating the fused and non-fused branches
// (and the closing braces) are missing from this excerpt.
164 void MKLDNNBatchNormalizationNode::createPrimitive() {
// Fused path: additionally binds the packed scale/shift weights blob.
168 if (fusedWithScale()) {
169 auto prim_desc = createPrimitiveDescriptor<batch_normalization_forward::primitive_desc,
170 batch_normalization_forward::desc>();
171 prim.reset(new batch_normalization_forward(prim_desc,
172 getParentEdgeAt(0)->getMemory().GetPrimitive(),
173 (const primitive::at) internalBlobMemory[1]->GetPrimitive(),
174 (const primitive::at) internalBlobMemory[0]->GetPrimitive(),
175 (const primitive::at) internalBlobMemory[2]->GetPrimitive(),
176 getChildEdgeAt(0)->getMemory().GetPrimitive()));
// Plain path: src, mean, variance, dst only.
178 auto prim_desc = createPrimitiveDescriptor<batch_normalization_forward::primitive_desc,
179 batch_normalization_forward::desc>();
180 prim.reset(new batch_normalization_forward(prim_desc,
181 getParentEdgeAt(0)->getMemory().GetPrimitive(),
182 (const primitive::at) internalBlobMemory[1]->GetPrimitive(),
183 (const primitive::at) internalBlobMemory[0]->GetPrimitive(),
184 getChildEdgeAt(0)->getMemory().GetPrimitive()));
// Creates a BN forward-scoring descriptor for the given input tensor desc.
// 2D inputs (NC) are promoted to 4D NCHW with H = W = 1 because the BN
// primitive works on 4D layouts. Always runs with precomputed statistics
// (use_global_stats); adds use_scaleshift when a ScaleShift layer is fused.
// NOTE(review): the tail of the batch_normalization_forward::desc constructor
// call (presumably "flag)));") is missing from this excerpt.
188 void MKLDNNBatchNormalizationNode::createDescriptor(const std::vector<InferenceEngine::TensorDesc> &inputDesc,
189 const std::vector<InferenceEngine::TensorDesc> &outputDesc) {
190 MKLDNNMemoryDesc inDesc(inputDesc[0]);
191 if (inDesc.getDims().ndims() == 2) {
// Promote NC -> NCHW by appending singleton spatial dims.
193 MKLDNNDims dims = inDesc.getDims();
194 dims.push_back(1); // H
195 dims.push_back(1); // W
196 auto format = memory::nchw;
197 inDesc = MKLDNNMemoryDesc(dims, inDesc.getDataType(), format);
// Inference always uses the statistics loaded from the IR, never computed ones.
200 unsigned flag = mkldnn_use_global_stats;
201 if (fusedWithScale())
202 flag |= mkldnn_use_scaleshift;
203 MKLDNNDescriptor desc(std::shared_ptr<batch_normalization_forward::desc>(
204 new batch_normalization_forward::desc(prop_kind::forward_scoring, inDesc, eps,
206 descs.push_back(desc);
// Finalizes the selected config so the input and output tensor descriptors
// are identical (BN runs in-place-compatible layouts): propagate whichever
// side is already initialized, or derive both from the configured input.
// NOTE(review): early "return;" statements (e.g. after isInitConfig) and some
// braces are missing from this excerpt.
209 void MKLDNNBatchNormalizationNode::initOptimalPrimitiveDescriptor() {
210 auto config = getSelectedPrimitiveDescriptor()->getConfig();
211 if (isInitConfig(config))
// The selected config must be exactly 1-in/1-out with matching descriptors.
214 if (config.inConfs.size() != 1 || config.outConfs.size() != 1 || (!isUninitTensorDesc(config.inConfs[0].desc) &&
215 !isUninitTensorDesc(config.outConfs[0].desc) && config.inConfs[0].desc != config.outConfs[0].desc))
216 THROW_IE_EXCEPTION << "Layer " << getName() << " has incorrect selected config!";
// Mirror whichever descriptor is initialized onto the other side.
218 if (!isUninitTensorDesc(config.inConfs[0].desc)) {
219 config.outConfs[0].desc = config.inConfs[0].desc;
220 } else if (!isUninitTensorDesc(config.outConfs[0].desc)) {
221 config.inConfs[0].desc = config.outConfs[0].desc;
// Neither side is set: compute the input descriptor and use it for both.
223 config.outConfs[0].desc = config.inConfs[0].desc = getConfiguredInputDesc(config, 0);
226 initDescriptor(config);
// Enumerates every primitive descriptor MKL-DNN offers for each created BN
// descriptor and records a LayerConfig (with dynamic-batch support and
// in-place output where possible) per implementation.
// NOTE(review): the "do {" that pairs with the "} while (itpd.next());" at the
// bottom, an early "return;", and closing braces are missing from this excerpt.
229 void MKLDNNBatchNormalizationNode::initSupportedPrimitiveDescriptors() {
// Idempotent: skip if descriptors were already populated.
230 if (!supportedPrimitiveDescriptors.empty())
233 // BN primitive doesn't support strides
234 for (auto& desc : descs) {
235 primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
237 InferenceEngine::LayerConfig config;
238 config.dynBatchSupport = true;
// One input DataConfig per declared input of the descriptor.
239 for (size_t i = 0; i < desc.inputNumbers(); i++) {
240 InferenceEngine::DataConfig dataConfig;
241 dataConfig.inPlace = -1;
242 dataConfig.constant = false;
243 dataConfig.desc = getSrcMemDesc(itpd, i);
244 config.inConfs.push_back(dataConfig);
// Outputs may alias input 0 when in-place execution is possible.
247 for (size_t i = 0; i < desc.outputNumbers(); i++) {
248 InferenceEngine::DataConfig dataConfig;
249 dataConfig.inPlace = canBeInPlace() ? 0 : -1;
250 dataConfig.constant = false;
251 dataConfig.desc = getDstMemDesc(itpd, i);
252 config.outConfs.push_back(dataConfig);
// Record which implementation (jit/ref/etc.) this iterator position is.
254 impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
256 supportedPrimitiveDescriptors.emplace_back(config, impl_type);
257 } while (itpd.next());
// Returns the source memory descriptor for input "idx" of the candidate
// primitive, undoing the NC->NCHW promotion done in createDescriptor: a 2D
// parent edge whose primitive desc came back as NCHW is reshaped back to NC.
// NOTE(review): the second parameter line of the signature (presumably
// "size_t idx) {") and the layout argument of the ANY-branch TensorDesc are
// missing from this excerpt.
261 MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it,
263 TensorDesc desc = MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());
// 2D input was promoted to 4D for the primitive — present it as NC again.
265 if (getParentEdgeAt(0)->getDims().ndims() == 2 && desc.getLayout() == InferenceEngine::Layout::NCHW) {
266 desc.reshape(getParentEdgeAt(idx)->getDims().ToSizeVector(), InferenceEngine::Layout::NC);
267 return MKLDNNMemoryDesc(desc);
// ANY layout: rebuild the descriptor from the edge dims.
269 if (desc.getLayout() == InferenceEngine::Layout::ANY)
270 return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
271 getParentEdgeAt(idx)->getDims().ToSizeVector(),
// Concrete layout: keep the primitive's blocking description.
274 return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
275 getParentEdgeAt(idx)->getDims().ToSizeVector(),
276 desc.getBlockingDesc()));
// Returns the destination memory descriptor for output "idx" of the candidate
// primitive. Mirrors getSrcMemDesc but queries dst_primitive_desc and sizes
// the ANY/blocking branches from the child edge instead of the parent edge.
// NOTE(review): the second parameter line of the signature and the layout
// argument of the ANY-branch TensorDesc are missing from this excerpt.
279 MKLDNNMemoryDesc MKLDNNBatchNormalizationNode::getDstMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it,
281 TensorDesc desc = MKLDNNMemoryDesc(primitive_desc_it.dst_primitive_desc(idx).desc());
// 2D input was promoted to 4D for the primitive — present the output as NC again.
283 if (getParentEdgeAt(0)->getDims().ndims() == 2 && desc.getLayout() == InferenceEngine::Layout::NCHW) {
284 desc.reshape(getParentEdgeAt(idx)->getDims().ToSizeVector(), InferenceEngine::Layout::NC);
285 return MKLDNNMemoryDesc(desc);
// ANY layout: rebuild the descriptor from the child edge dims.
287 if (desc.getLayout() == InferenceEngine::Layout::ANY)
288 return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
289 getChildEdgeAt(idx)->getDims().ToSizeVector(),
// Concrete layout: keep the primitive's blocking description.
292 return MKLDNNMemoryDesc(InferenceEngine::TensorDesc(desc.getPrecision(),
293 getChildEdgeAt(idx)->getDims().ToSizeVector(),
294 desc.getBlockingDesc()));