1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <mkldnn_extension_mngr.h>
7 #include <mkldnn_extension_utils.h>
8 #include "mkldnn_generic_node.h"
11 #include <blob_factory.hpp>
13 using namespace mkldnn;
14 using namespace MKLDNNPlugin;
// Thin constructor: forwards the CNN layer description and the MKLDNN engine
// to the base MKLDNNNode; all extension wiring happens later in created(extMgr).
16 MKLDNNGenericNode::MKLDNNGenericNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
// Validates that exactly one extensibility mechanism backs this generic layer:
// either a genericPrimitive (IE extension primitive) or an extFactory
// (ILayerImplFactory). Throws when neither exists.
// NOTE(review): this chunk is truncated — the body of the second `if`
// (genericPrimitive && extFactory, presumably resolving the ambiguity) and the
// closing braces are not visible here; confirm against the full file.
18 void MKLDNNGenericNode::getSupportedDescriptors() {
19     if (!genericPrimitive && !extFactory) {
20         std::string type = getCnnLayer() ? getCnnLayer()->type : "Generic";
21         THROW_IE_EXCEPTION << "Cannot get generic primitive for layer: " << getName() << " with type: " << type;
23     if (genericPrimitive && extFactory) {
// Builds supportedPrimitiveDescriptors for the node from one of two sources:
//  - genericPrimitive: iterates the formats the extension primitive reports,
//    converting each input/output MemoryFormat into an MKLDNNMemoryDesc;
//  - extFactory: queries ILayerImplFactory for implementations and copies
//    each implementation's supported LayerConfig verbatim.
// Precisions other than FP32 are forced to FP32 on both inputs and outputs.
// NOTE(review): this chunk is truncated — several `continue`/`break`
// statements, `isAny` assignments, and closing braces between the visible
// lines are missing from this view; the comments below describe only what is
// visible.
28 void MKLDNNGenericNode::initSupportedPrimitiveDescriptors() {
29     if (!supportedPrimitiveDescriptors.empty())
// Force FP32 data types regardless of the layer's declared precisions.
32     InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
33     if (precision != InferenceEngine::Precision::FP32)
34         precision = InferenceEngine::Precision::FP32;
35     auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
36     precision = getCnnLayer()->outData[0]->getPrecision();
37     if (precision != InferenceEngine::Precision::FP32)
38         precision = InferenceEngine::Precision::FP32;
39     auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
41     if (genericPrimitive) {
42         std::vector<InferenceEngine::MKLDNNPlugin::MKLDNNGenericFormats> formats = genericPrimitive->GetSupportedFormats();
44             THROW_IE_EXCEPTION << "External primitive doesn't have supported formats";
// Helper: true when a single declared format must be broadcast across 2+
// edges that all share identical dims (so one desc can serve all edges).
45         auto createAllDesc = [](const std::vector<MKLDNNEdgeWeakPtr> & edges,
46                                 const std::vector<InferenceEngine::MKLDNNPlugin::MemoryFormat>& formats) {
47             if (formats.size() != 1 || edges.size() < 2)
49             auto firstDims = edges[0].lock()->getDims();
50             for (size_t i = 1; i < edges.size(); i++) {
51                 if (firstDims != edges[i].lock()->getDims())
// One LayerConfig candidate per reported format.
56         for (auto &format : formats) {
58             bool isNotAny = false;
59             InferenceEngine::LayerConfig config;
60             config.dynBatchSupport = false;
61             bool isCompatible = true;
// Inputs: reuse format[i] when available, else broadcast format[0].
62             bool allDescCreate = createAllDesc(getParentEdges(), format.GetInputs());
63             for (size_t i = 0; i < getParentEdges().size(); i++) {
64                 auto input_format = format.GetInputs()[0];
65                 if (format.GetInputs().size() > i)
66                     input_format = format.GetInputs()[i];
67                 else if (!allDescCreate)
// Skip formats whose layout cannot represent this edge's dims.
69                 if (!MKLDNNMemory::isConsistant(getParentEdgeAt(i)->getDims(),
70                                                 MKLDNNExtensionUtils::MemoryFormatToMKLFormat(input_format))) {
74                 mkldnn::memory::format mkldnnFormat = MKLDNNExtensionUtils::MemoryFormatToMKLFormat(input_format);
75                 InferenceEngine::DataConfig dataConf;
76                 dataConf.inPlace = -1;
77                 dataConf.constant = false;
78                 dataConf.desc = MKLDNNMemoryDesc(getParentEdgeAt(i)->getDims(), inputDataType, mkldnnFormat);
79                 config.inConfs.push_back(dataConf);
80                 if (dataConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
// Mixing ANY and concrete layouts within one format set is unsupported.
86             if (isAny && isNotAny) {
87                 THROW_IE_EXCEPTION << "Layer " << getName() << " has incorrect input formats "
88                                    << " (any and not any formats don't supported in the same time).";
// Outputs: mirror of the input handling above.
92             allDescCreate = createAllDesc(getChildEdges(), format.GetOutputs());
93             for (size_t i = 0; i < getChildEdges().size(); i++) {
94                 auto output_format = format.GetOutputs()[0];
95                 if (format.GetOutputs().size() > i)
96                     output_format = format.GetOutputs()[i];
97                 else if (!allDescCreate)
99                 if (!MKLDNNMemory::isConsistant(getChildEdgeAt(i)->getDims(),
100                                                 MKLDNNExtensionUtils::MemoryFormatToMKLFormat(
102                     isCompatible = false;
105                 mkldnn::memory::format mkldnnFormat = MKLDNNExtensionUtils::MemoryFormatToMKLFormat(output_format);
106                 InferenceEngine::DataConfig dataConf;
107                 dataConf.inPlace = -1;
108                 dataConf.constant = false;
109                 dataConf.desc = MKLDNNMemoryDesc(getChildEdgeAt(i)->getDims(), outputDataType, mkldnnFormat);
110                 config.outConfs.push_back(dataConf);
111                 if (dataConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
117             if (isAny && isNotAny) {
118                 THROW_IE_EXCEPTION << "Layer " << getName() << " has incorrect output formats "
119                                    << " (any and not any formats don't supported in the same time).";
// Implementation type is unknown for external primitives.
122                 supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
125     } else if (extFactory) {
126         InferenceEngine::ResponseDesc resp;
127         InferenceEngine::StatusCode rc = extFactory->getImplementations(impls, &resp);
128         if (rc != InferenceEngine::OK) {
129             THROW_IE_EXCEPTION << resp.msg;
// Each factory implementation may expose several configurations.
131         for (auto &impl : impls) {
132             std::vector<InferenceEngine::LayerConfig> configs;
133             rc = impl->getSupportedConfigurations(configs, &resp);
134             if (rc != InferenceEngine::OK) {
135                 THROW_IE_EXCEPTION << resp.msg;
138             for (auto& config : configs) {
139                 supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
143             THROW_IE_EXCEPTION << "Layer " << getName() << " hasn't available configurations!";
146         THROW_IE_EXCEPTION << "Descriptor for generic primitive doesn't exist";
// Sanity checks before execution: a genericPrimitive must exist and a
// primitive descriptor must have been selected. No MKLDNN primitive object
// is actually built here — the extension primitive executes itself.
// NOTE(review): lines between the signature and the first check are missing
// from this chunk (presumably an early-return for the extFactory path —
// confirm in the full file).
150 void MKLDNNGenericNode::createPrimitive() {
154     if (!genericPrimitive)
155         THROW_IE_EXCEPTION << "Descriptor for generic primitive doesn't exist";
156     if (getSelectedPrimitiveDescriptor() == nullptr)
157         THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
// Runs the layer: for a genericPrimitive, converts every parent/child edge's
// MKLDNN memory into the extension's generic memory wrapper, hands both lists
// to the primitive and executes it; for the factory path (impls non-empty)
// execution is delegated elsewhere (body truncated in this view, presumably
// execLayer()).
// NOTE(review): `inputs`/`outputs` appear without a visible declaration —
// their declarations (and any clear() between calls) are in the truncated
// lines; verify they are not accumulated across invocations.
160 void MKLDNNGenericNode::execute(mkldnn::stream strm) {
161     if (genericPrimitive) {
162         for (size_t i = 0; i < getParentEdges().size(); i++) {
163             auto& mklMemory = getParentEdgeAt(i)->getMemory();
164             inputs.push_back(MKLDNNExtensionUtils::MKLMemoryToGenericMemory(mklMemory));
167         for (size_t i = 0; i < getChildEdges().size(); i++) {
168             auto& mklMemory = getChildEdgeAt(i)->getMemory();
169             outputs.push_back(MKLDNNExtensionUtils::MKLMemoryToGenericMemory(mklMemory));
172         genericPrimitive->SetMemory(inputs, outputs);
173         genericPrimitive->Execute();
174     } else if (!impls.empty()) {
177         THROW_IE_EXCEPTION << "Descriptor for generic primitive doesn't exist";
// A generic node is "created" simply when its node type is Generic; the
// heavier extMgr overload below performs the actual extension lookup.
181 bool MKLDNNGenericNode::created() const {
182     return Generic == getType();
// Tries to back this node with an extension: asks the extension manager for a
// primitive and/or a factory for the CNN layer. Returns true (via the
// truncated tail) when either was created. The manager itself is retained so
// it outlives the primitives it produced.
185 bool MKLDNNGenericNode::created(const MKLDNNExtensionManager::Ptr &extMgr) {
186     if (getCnnLayer() && extMgr) {
187         // We should save extension manager in order to avoid situation when
188         // it will be destroyed before extensibility primitives
189         extensionManager = extMgr;
190         genericPrimitive.reset(extensionManager->CreateExtensionPrimitive(getCnnLayer()));
191         extFactory.reset(extensionManager->CreateExtensionFactory(getCnnLayer()));
193         if (genericPrimitive || extFactory)
// Delegates to the base-class cleanup; extension-specific members are released
// in the destructor instead.
199 void MKLDNNGenericNode::cleanup() {
200     MKLDNNNode::cleanup();
// Factory-path execution: wraps edge memory into IE blobs, optionally shrinks
// the batch dimension when dynamic batching is active (dynBatchLim > 0), asks
// the factory for the resulting output shapes, and runs the first
// implementation cast to ILayerExecImpl.
// NOTE(review): truncated — the condition guarding the getShapes() call and
// the else-branch selecting original output blobs are only partially visible.
204 void MKLDNNGenericNode::execLayer() {
205     bool isDynBatch = dynBatchLim > 0;
206     std::vector<InferenceEngine::Blob::Ptr> inputs;
207     std::vector<InferenceEngine::TensorDesc> inputDescs;
208     std::vector<InferenceEngine::TensorDesc> outputDescs;
209     for (size_t i = 0; i < getParentEdges().size(); i++) {
210         inputs.push_back(getParentEdgeAt(i)->getBlob());
// Dynamic batch only applies when the limit is below the blob's batch dim.
211         if (isDynBatch && dynBatchLim >= inputs[inputs.size() - 1]->getTensorDesc().getDims()[0]) {
214             // TODO: Ask the right dims using getShape() from previous node
215             inputDescs.push_back(inputs[inputs.size() - 1]->getTensorDesc());
216             inputDescs[inputDescs.size() - 1].getDims()[0] = static_cast<size_t>(batchToProcess());
// Let the factory recompute output shapes for the reduced batch.
221         auto sts = extFactory->getShapes(inputDescs, outputDescs, nullptr);
222         if (sts != InferenceEngine::StatusCode::OK)
// Re-wrap input memory with the batch-reduced descriptors (no copy: blobs
// alias the edge memory via GetData()).
227         for (size_t i = 0; i < inputs.size(); i++) {
228             auto td = inputs[i]->getTensorDesc();
229             td.setDims(inputDescs[i].getDims());
230             inputs[i] = make_blob_with_precision(td, getParentEdgeAt(i)->getMemory().GetData());
233     std::vector<InferenceEngine::Blob::Ptr> outputs;
234     for (size_t i = 0; i < getChildEdges().size(); i++) {
// Fewer output descs than edges: fall back to desc 0 for the extras.
236             size_t idx = i >= outputDescs.size() ? 0 : i;
237             auto td = getChildEdgeAt(i)->getBlob()->getTensorDesc();
238             td.setDims(outputDescs[idx].getDims());
239             outputs.push_back(make_blob_with_precision(td, getChildEdgeAt(i)->getMemory().GetData()));
241             outputs.push_back(getChildEdgeAt(i)->getBlob());
// Only the first implementation is executed; it must support ILayerExecImpl.
244     auto * execImpl = dynamic_cast<InferenceEngine::ILayerExecImpl *>(impls[0].get());
245     if (execImpl != nullptr) {
246         InferenceEngine::ResponseDesc resp;
247         InferenceEngine::StatusCode rc = execImpl->execute(inputs, outputs, &resp);
248         if (rc != InferenceEngine::OK) {
249             THROW_IE_EXCEPTION << resp.msg;
// Destructor: release the extension primitive BEFORE the extension manager,
// since the manager owns the library that implements the primitive.
254 MKLDNNGenericNode::~MKLDNNGenericNode() {
256     genericPrimitive.reset();
257     extensionManager.reset();
// Finalizes the chosen LayerConfig. For a genericPrimitive, ANY layouts are
// materialized via getLayoutByDims and all descs are rebuilt as plain blocked
// TensorDescs. For the factory path, the implementation matching
// selectedPrimitiveDescriptorIndex is located (flattened across all
// implementations' config lists), in-place reuse is disabled on shared edges,
// and the selected implementation is init()'ed with the final config.
// Finally the node is marked Const when every in/out config is constant.
// NOTE(review): truncated — several closing braces and the branch structure
// around the ANY/else paths are not fully visible in this chunk.
260 void MKLDNNGenericNode::initDescriptor(const InferenceEngine::LayerConfig &config) {
261     InferenceEngine::LayerConfig rightConfig = config;
262     if (genericPrimitive) {
263         for (auto &inConf : rightConfig.inConfs) {
264             inConf.constant = false;
// ANY layout: pick the default layout for this rank.
266             if (inConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
267                 inConf.desc = InferenceEngine::TensorDesc(inConf.desc.getPrecision(),
268                                                           inConf.desc.getDims(),
269                                                           InferenceEngine::TensorDesc::getLayoutByDims(
270                                                                   inConf.desc.getDims()));
// Concrete layout: rebuild from the existing blocking descriptor.
272                 inConf.desc = InferenceEngine::TensorDesc(inConf.desc.getPrecision(),
273                                                           inConf.desc.getDims(), {
274                                                                   inConf.desc.getBlockingDesc().getBlockDims(),
275                                                                   inConf.desc.getBlockingDesc().getOrder()
279         for (auto &outConf : rightConfig.outConfs) {
280             outConf.constant = false;
281             outConf.inPlace = -1;
282             if (outConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
283                 outConf.desc = InferenceEngine::TensorDesc(outConf.desc.getPrecision(),
284                                                            outConf.desc.getDims(),
285                                                            InferenceEngine::TensorDesc::getLayoutByDims(
286                                                                    outConf.desc.getDims()));
288                 outConf.desc = InferenceEngine::TensorDesc(outConf.desc.getPrecision(),
289                                                            outConf.desc.getDims(), {
290                                                                    outConf.desc.getBlockingDesc().getBlockDims(),
291                                                                    outConf.desc.getBlockingDesc().getOrder()
296         InferenceEngine::StatusCode rc;
297         InferenceEngine::ResponseDesc resp;
// Walk all implementations' configs with a flat counter t to find the one
// the scheduler selected.
299         InferenceEngine::ILayerImpl::Ptr selectedImpl;
300         for (size_t k = 0, t = 0; k < impls.size(); k++) {
301             std::vector<InferenceEngine::LayerConfig> configs;
302             rc = impls[k]->getSupportedConfigurations(configs, &resp);
303             if (rc != InferenceEngine::OK) {
304                 THROW_IE_EXCEPTION << resp.msg;
306             for (size_t j = 0; j < configs.size(); j++, t++) {
307                 if (t == selectedPrimitiveDescriptorIndex) {
308                     selectedImpl = impls[k];
// In-place is unsafe when the producing parent feeds multiple consumers.
313         for (size_t j = 0; j < rightConfig.inConfs.size(); j++) {
314             if (getParentEdgeAt(j)->getParent()->getChildEdges().size() > 1) {
315                 rightConfig.inConfs[j].inPlace = -1;
318         for (auto &outConf : rightConfig.outConfs) {
319             if (outConf.inPlace < getParentEdges().size() &&
320                 getParentEdgeAt(static_cast<size_t>(outConf.inPlace))->getParent()->getChildEdges().size() > 1) {
321                 outConf.inPlace = -1;
// Keep only the selected implementation and initialize it.
327         impls.emplace_back(selectedImpl);
328         rc = impls[0]->init(rightConfig, &resp);
329         if (rc != InferenceEngine::OK) {
330             THROW_IE_EXCEPTION << resp.msg;
334     getSelectedPrimitiveDescriptor()->getConfig() = rightConfig;
// Const node only when configs exist AND every one of them is constant.
335     bool isConst = !rightConfig.inConfs.empty() || !rightConfig.outConfs.empty();
336     for (const auto &inConf : rightConfig.inConfs) {
337         isConst = isConst && inConf.constant;
339     for (const auto &outConf : rightConfig.outConfs) {
340         isConst = isConst && outConf.constant;
343         constant = ConstantType::Const;
// Resolves uninitialized/ANY tensor descriptors in the selected config for
// the genericPrimitive path by propagating the parent node's output desc:
// if the parent's desc matches, it is adopted for this input (and, for ANY
// layouts, fanned out to the remaining ANY inputs and outputs); otherwise a
// blocked desc is rebuilt from either this config's or the parent's blocking
// info. Outputs are then finalized via getConfiguredOutputDesc and the result
// fed back through initDescriptor.
// NOTE(review): truncated — `continue` statements, an early base-class
// delegation for the non-generic path, and several closing braces are missing
// from this view; the propagation nesting below cannot be fully verified here.
347 void MKLDNNGenericNode::initOptimalPrimitiveDescriptor() {
348     auto config = getSelectedPrimitiveDescriptor()->getConfig();
349     if (genericPrimitive) {
350         if (isInitConfig(config))
353         for (size_t i = 0; i < config.inConfs.size(); i++) {
354             if (!isUninitTensorDesc(config.inConfs[i].desc))
// Locate the parent's output config that feeds this input edge.
356             int num = getParentEdgeAt(i)->getInputNum();
357             if (getParentEdgeAt(i)->getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size() <= num)
359             auto parentConf = getParentEdgeAt(i)->getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs[num];
360             if (getParentEdgeAt(i)->getParent()->getSelectedPrimitiveDescriptor()) {
// Parent still unresolved: force it to finalize first, then re-read.
362                 if (isUninitTensorDesc(parentConf.desc) && (parentConf.inPlace >= 0 ||
363                     parentConf.desc.getLayout() == InferenceEngine::Layout::ANY))
364                     getParentEdgeAt(i)->getParent()->initOptimalPrimitiveDescriptor();
365                 parentConf = getParentEdgeAt(i)->getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs[num];
366                 if (!isUninitTensorDesc(parentConf.desc) &&
367                     MKLDNNExtensionUtils::initTensorsAreEqual(parentConf.desc, config.inConfs[i].desc)) {
// Compatible parent desc: adopt it, and propagate to every other ANY desc.
368                     if (config.inConfs[i].desc.getLayout() == InferenceEngine::Layout::ANY) {
369                         for (size_t j = i + 1; j < config.inConfs.size(); j++) {
370                             if (config.inConfs[j].desc.getLayout() == InferenceEngine::Layout::ANY) {
371                                 config.inConfs[j].desc = parentConf.desc;
374                         for (auto &outConf : config.outConfs) {
375                             if (outConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
376                                 outConf.desc = parentConf.desc;
380                     config.inConfs[i].desc = parentConf.desc;
// Incompatible parent desc: rebuild a blocked desc from this config...
385                 if (config.inConfs[i].desc.getLayout() != InferenceEngine::Layout::ANY) {
386                     config.inConfs[i].desc = InferenceEngine::TensorDesc(config.inConfs[i].desc.getPrecision(),
387                                                                          config.inConfs[i].desc.getDims(), {
388                                                                                  config.inConfs[i].desc.getBlockingDesc().getBlockDims(),
389                                                                                  config.inConfs[i].desc.getBlockingDesc().getOrder()
// ...or, when this config is ANY, from the parent's blocking info.
391                 } else if (parentConf.desc.getLayout() != InferenceEngine::Layout::ANY) {
392                     if (config.inConfs[i].desc.getLayout() == InferenceEngine::Layout::ANY) {
393                         for (size_t j = i + 1; j < config.inConfs.size(); j++) {
394                             if (config.inConfs[j].desc.getLayout() == InferenceEngine::Layout::ANY) {
395                                 config.inConfs[j].desc = InferenceEngine::TensorDesc(parentConf.desc.getPrecision(),
396                                                                                      parentConf.desc.getDims(), {
397                                                                                              parentConf.desc.getBlockingDesc().getBlockDims(),
398                                                                                              parentConf.desc.getBlockingDesc().getOrder()
402                         for (auto &outConf : config.outConfs) {
403                             if (outConf.desc.getLayout() == InferenceEngine::Layout::ANY) {
404                                 outConf.desc = InferenceEngine::TensorDesc(parentConf.desc.getPrecision(),
405                                                                            parentConf.desc.getDims(), {
406                                                                                    parentConf.desc.getBlockingDesc().getBlockDims(),
407                                                                                    parentConf.desc.getBlockingDesc().getOrder()
412                     config.inConfs[i].desc = InferenceEngine::TensorDesc(parentConf.desc.getPrecision(),
413                                                                          parentConf.desc.getDims(), {
414                                                                                  parentConf.desc.getBlockingDesc().getBlockDims(),
415                                                                                  parentConf.desc.getBlockingDesc().getOrder()
// Last resort: both ANY — pick the default layout for the dims.
418                     config.inConfs[i].desc = InferenceEngine::TensorDesc(config.inConfs[i].desc.getPrecision(),
419                                                                          config.inConfs[i].desc.getDims(),
420                                                                          InferenceEngine::TensorDesc::getLayoutByDims(config.inConfs[i].desc.getDims()));
424         for (size_t i = 0; i < config.outConfs.size(); i++) {
425             config.outConfs[i].desc = getConfiguredOutputDesc(config, i);
429     initDescriptor(config);