1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <mkldnn_extension_mngr.h>
6 #include <mkldnn_extension_utils.h>
7 #include "mkldnn_generic_node.h"
10 #include <blob_factory.hpp>
12 using namespace mkldnn;
13 using namespace MKLDNNPlugin;
15 MKLDNNGenericNode::MKLDNNGenericNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng) : MKLDNNNode(layer, eng) {}
// A generic node has no built-in MKLDNN descriptors; reaching this method
// means no extension implementation could serve the layer, so fail loudly.
void MKLDNNGenericNode::getSupportedDescriptors() {
    // getCnnLayer() may be null — fall back to a generic type name for the message.
    std::string type = getCnnLayer() ? getCnnLayer()->type : "Generic";
    THROW_IE_EXCEPTION << "Cannot get generic primitive for layer: " << getName() << " with type: " << type;
// Queries the extension factory for its implementations, asks each one for its
// supported configurations, and registers every configuration as a supported
// primitive descriptor (implementation type "unknown" to MKLDNN).
void MKLDNNGenericNode::initSupportedPrimitiveDescriptors() {
    // Skip re-initialization if descriptors were already collected.
    if (!supportedPrimitiveDescriptors.empty())

    // Generic extension kernels operate on FP32 only, so any other input
    // precision is coerced to FP32 before converting to an MKLDNN data type.
    InferenceEngine::Precision precision = getCnnLayer()->insData[0].lock()->getPrecision();
    if (precision != InferenceEngine::Precision::FP32)
        precision = InferenceEngine::Precision::FP32;
    auto inputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
    // Same FP32 coercion for the output side.
    precision = getCnnLayer()->outData[0]->getPrecision();
    if (precision != InferenceEngine::Precision::FP32)
        precision = InferenceEngine::Precision::FP32;
    auto outputDataType = MKLDNNExtensionUtils::IEPrecisionToDataType(precision);
    // NOTE(review): inputDataType/outputDataType are consumed by code elided
    // from this view — confirm against the full file.

    // NOTE(review): presumably guarded by an `if (!extFactory)` check in the
    // full file — fires when no extension factory exists for this layer.
    THROW_IE_EXCEPTION << "Descriptor for generic primitive doesn't exist";

    InferenceEngine::ResponseDesc resp;
    InferenceEngine::StatusCode rc = extFactory->getImplementations(impls, &resp);
    if (rc != InferenceEngine::OK) {
        THROW_IE_EXCEPTION << resp.msg;

    for (auto &impl : impls) {
        std::vector<InferenceEngine::LayerConfig> configs;
        rc = impl->getSupportedConfigurations(configs, &resp);
        if (rc != InferenceEngine::OK) {
            THROW_IE_EXCEPTION << resp.msg;

        // Every configuration the implementation reports is acceptable.
        for (auto& config : configs) {
            supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);

    // NOTE(review): presumably guarded by an emptiness check in the full file
    // — fires when no implementation produced a usable configuration.
    THROW_IE_EXCEPTION << "Layer " << getName() << " hasn't available configurations!";
61 void MKLDNNGenericNode::createPrimitive() {
65 if (getSelectedPrimitiveDescriptor() == nullptr)
66 THROW_IE_EXCEPTION << "Preferable primitive descriptor does not set.";
// Executes the node. For a generic node the real work happens in the
// extension implementation (see execLayer()); the throw below reports that no
// implementation descriptor is available.
void MKLDNNGenericNode::execute(mkldnn::stream strm) {
    // NOTE(review): presumably guarded by an `impls.empty()` check (with the
    // execLayer() call on the other branch) in the full file — confirm.
    THROW_IE_EXCEPTION << "Descriptor for generic primitive doesn't exist";
77 bool MKLDNNGenericNode::created() const {
78 return Generic == getType();
// Variant of created() that additionally builds the extension factory for
// this layer via the provided extension manager.
bool MKLDNNGenericNode::created(const MKLDNNExtensionManager::Ptr &extMgr) {
    if (getCnnLayer() && extMgr) {
        // We keep the extension factory alive on the node in order to avoid a
        // situation where it is destroyed before the extensibility primitives
        // it produced.
        extFactory.reset(extMgr->CreateExtensionFactory(getCnnLayer()));
// Releases per-node resources; delegates to the base class implementation.
void MKLDNNGenericNode::cleanup() {
    MKLDNNNode::cleanup();
// Runs the extension implementation over this node's input/output memory.
// When dynamic batching is active (dynBatchLim > 0) the tensor descriptors
// are re-shaped to the batch currently being processed before the blobs are
// handed to the implementation.
void MKLDNNGenericNode::execLayer() {
    bool isDynBatch = dynBatchLim > 0;
    std::vector<InferenceEngine::Blob::Ptr> inputs;
    std::vector<InferenceEngine::TensorDesc> inputDescs;
    std::vector<InferenceEngine::TensorDesc> outputDescs;
    // Collect one input blob per parent edge; under dynamic batching also
    // collect a descriptor whose dim[0] is clamped to the batch to process.
    for (size_t i = 0; i < getParentEdges().size(); i++) {
        inputs.push_back(getParentEdgeAt(i)->getBlob());
        // Dyn-batch applies only when the limit covers the blob's batch dim.
        if (isDynBatch && dynBatchLim >= inputs[inputs.size() - 1]->getTensorDesc().getDims()[0]) {
        // TODO: Ask the right dims using getShape() from previous node
        inputDescs.push_back(inputs[inputs.size() - 1]->getTensorDesc());
        // dim[0] is treated as the batch dimension.
        inputDescs[inputDescs.size() - 1].getDims()[0] = static_cast<size_t>(batchToProcess());

    // Ask the extension factory to infer output shapes for the reduced batch.
    auto sts = extFactory->getShapes(inputDescs, outputDescs, nullptr);
    if (sts != InferenceEngine::StatusCode::OK)

    // Re-wrap each input as a blob with the reduced-batch descriptor over the
    // same underlying memory (no data copy).
    for (size_t i = 0; i < inputs.size(); i++) {
        auto td = inputs[i]->getTensorDesc();
        td.setDims(inputDescs[i].getDims());
        inputs[i] = make_blob_with_precision(td, getParentEdgeAt(i)->getMemory().GetData());

    std::vector<InferenceEngine::Blob::Ptr> outputs;
    for (size_t i = 0; i < getChildEdges().size(); i++) {
        // If fewer output descs than child edges were produced, reuse desc 0.
        size_t idx = i >= outputDescs.size() ? 0 : i;
        auto td = getChildEdgeAt(i)->getBlob()->getTensorDesc();
        td.setDims(outputDescs[idx].getDims());
        outputs.push_back(make_blob_with_precision(td, getChildEdgeAt(i)->getMemory().GetData()));
        // NOTE(review): this push_back looks like the non-dyn-batch branch of
        // a condition elided from this view — confirm against the full file.
        outputs.push_back(getChildEdgeAt(i)->getBlob());

    // Only the first implementation is executed; it must support execution.
    auto * execImpl = dynamic_cast<InferenceEngine::ILayerExecImpl *>(impls[0].get());
    if (execImpl != nullptr) {
        InferenceEngine::ResponseDesc resp;
        InferenceEngine::StatusCode rc = execImpl->execute(inputs, outputs, &resp);
        if (rc != InferenceEngine::OK) {
            THROW_IE_EXCEPTION << resp.msg;
// Re-resolves the implementation matching the selected primitive descriptor
// index, normalizes in-place hints for edges shared by several consumers,
// initializes the implementation with the final config, and marks the node
// constant when every port configuration is constant.
void MKLDNNGenericNode::initDescriptor(const InferenceEngine::LayerConfig &config) {
    InferenceEngine::LayerConfig rightConfig = config;
    InferenceEngine::StatusCode rc;
    InferenceEngine::ResponseDesc resp;

    // Walk all implementations/configurations with a flat counter `t` until it
    // equals selectedPrimitiveDescriptorIndex — same enumeration order as in
    // initSupportedPrimitiveDescriptors().
    InferenceEngine::ILayerImpl::Ptr selectedImpl;
    for (size_t k = 0, t = 0; k < impls.size(); k++) {
        std::vector<InferenceEngine::LayerConfig> configs;
        rc = impls[k]->getSupportedConfigurations(configs, &resp);
        if (rc != InferenceEngine::OK) {
            THROW_IE_EXCEPTION << resp.msg;
        for (size_t j = 0; j < configs.size(); j++, t++) {
            if (t == selectedPrimitiveDescriptorIndex) {
                selectedImpl = impls[k];

    // In-place reuse is unsafe when the producer feeds more than one consumer,
    // so drop the in-place hint for such inputs.
    for (size_t j = 0; j < rightConfig.inConfs.size(); j++) {
        if (getParentEdgeAt(j)->getParent()->getChildEdges().size() > 1) {
            rightConfig.inConfs[j].inPlace = -1;
    // Same rule for outputs that want to alias an input port. Note: the
    // signed `inPlace` converts to unsigned here, so inPlace == -1 becomes a
    // huge value and (correctly) fails the bounds check.
    for (auto &outConf : rightConfig.outConfs) {
        if (outConf.inPlace < getParentEdges().size() &&
            getParentEdgeAt(static_cast<size_t>(outConf.inPlace))->getParent()->getChildEdges().size() > 1) {
            outConf.inPlace = -1;

    // Keep only the selected implementation and initialize it with the
    // normalized config. NOTE(review): `impls` is presumably cleared just
    // before this in the full file, making the emplaced element impls[0] —
    // confirm.
    impls.emplace_back(selectedImpl);
    rc = impls[0]->init(rightConfig, &resp);
    if (rc != InferenceEngine::OK) {
        THROW_IE_EXCEPTION << resp.msg;

    // Persist the final config back into the selected primitive descriptor.
    auto descriptor = getSelectedPrimitiveDescriptor();
    if (descriptor != nullptr) {
        descriptor->getConfig() = rightConfig;

    // Constant iff at least one port config exists and all in/out configs are
    // marked constant.
    bool isConst = !rightConfig.inConfs.empty() || !rightConfig.outConfs.empty();
    for (const auto &inConf : rightConfig.inConfs) {
        isConst = isConst && inConf.constant;
    for (const auto &outConf : rightConfig.outConfs) {
        isConst = isConst && outConf.constant;
        // NOTE(review): guarded by `if (isConst)` in the full file — the guard
        // is elided from this view.
        constant = ConstantType::Const;