Publishing 2019 R1 content
platform/upstream/dldt.git: inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mkldnn_plugin.h"
#include "mkldnn_extension_mngr.h"
#include <cpp_interfaces/base/ie_plugin_base.hpp>
#include <memory>

using namespace MKLDNNPlugin;
using namespace InferenceEngine;

// Engine-level statics: the weights cache shared across loaded networks and
// the CRC hash helper used to key entries in it.
MKLDNNWeightsSharing Engine::weightsSharing;
const SimpleDataHash MKLDNNWeightsSharing::simpleCRC;

InferenceEngine::ExecutableNetworkInternal::Ptr
Engine::LoadExeNetworkImpl(InferenceEngine::ICNNNetwork &network, const std::map<std::string, std::string> &config) {
    auto specifiedDevice = network.getTargetDevice();
    auto supportedDevice = InferenceEngine::TargetDevice::eCPU;
    if (specifiedDevice != InferenceEngine::TargetDevice::eDefault && specifiedDevice != supportedDevice) {
        THROW_IE_EXCEPTION << "The plugin doesn't support target device: " << getDeviceName(specifiedDevice) << ".\n" <<
                           "Supported target device: " << getDeviceName(supportedDevice);
    }

    // verification of supported input
    InferenceEngine::InputsDataMap _networkInputs;
    network.getInputsInfo(_networkInputs);
    for (const auto &ii : _networkInputs) {
        auto input_precision = ii.second->getInputPrecision();
        if (input_precision != InferenceEngine::Precision::FP32 &&
            input_precision != InferenceEngine::Precision::I32 &&
            input_precision != InferenceEngine::Precision::U16 &&
            input_precision != InferenceEngine::Precision::I16 &&
            input_precision != InferenceEngine::Precision::I8 &&
            input_precision != InferenceEngine::Precision::U8) {
            THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str
                               << "Input image format " << input_precision << " is not supported yet...";
        }
    }

    // TODO: handle input precision differently - per input and not one per network...

    // TODO: Clarify the behavior of SetConfig method. Skip eng_config or not?
    Config conf = engConfig;
    conf.readProperties(config);

    if (conf.enableDynamicBatch) {
        conf.batchLimit = static_cast<int>(network.getBatchSize());
    }

    return std::make_shared<MKLDNNExecNetwork>(network, conf, extensionManager);
}
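// Usage sketch (illustration only, not part of this file): clients reach
// LoadExeNetworkImpl through the public API. With the 2019 R1 API this looks
// roughly like the following; the model paths are placeholders:
//
//     InferenceEngine::Core ie;
//     InferenceEngine::CNNNetReader reader;
//     reader.ReadNetwork("model.xml");
//     reader.ReadWeights("model.bin");
//     std::map<std::string, std::string> config = {
//         {PluginConfigParams::KEY_DYN_BATCH_ENABLED, PluginConfigParams::YES}};
//     auto exeNetwork = ie.LoadNetwork(reader.getNetwork(), "CPU", config);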

void Engine::SetConfig(const std::map<std::string, std::string> &config) {
    // accumulate config parameters on engine level
    engConfig.readProperties(config);

    // Pass the config on to the already loaded network
    // TODO: Clarify the behavior of SetConfig method. Should it pass data to already loaded networks?
    if (_loadedNetwork) {
        // ugly casting. can we avoid it?
        // Both dynamic_casts can return nullptr, so check before dereferencing.
        auto exe_network =
                dynamic_cast<ExecutableNetworkBase<ExecutableNetworkInternal>*>(_loadedNetwork.get());
        if (exe_network == nullptr)
            THROW_IE_EXCEPTION << "Cannot cast the loaded network to ExecutableNetworkBase";
        auto exe_network_impl = dynamic_cast<MKLDNNExecNetwork*>(exe_network->getImpl().get());
        if (exe_network_impl == nullptr)
            THROW_IE_EXCEPTION << "The loaded network is not an MKLDNNExecNetwork";

        exe_network_impl->setProperty(config);
    }
}
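// Typical engine-level properties accepted here come from ie_plugin_config.hpp;
// a sketch (the exact set of keys honored depends on Config::readProperties):
//
//     plugin.SetConfig({{PluginConfigParams::KEY_CPU_THREADS_NUM, "4"},
//                       {PluginConfigParams::KEY_CPU_BIND_THREAD, PluginConfigParams::YES}});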

void Engine::AddExtension(InferenceEngine::IExtensionPtr extension) {
    extensionManager->AddExtension(extension);
}

void Engine::QueryNetwork(const ICNNNetwork& network, QueryNetworkResult& res) const {
    QueryNetwork(network, {}, res);
}

void Engine::QueryNetwork(const ICNNNetwork& network, const std::map<std::string, std::string>& config, QueryNetworkResult& res) const {
    // One CPU engine suffices for the whole query.
    mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
    details::CNNNetworkIterator i(const_cast<ICNNNetwork *>(&network));
    while (i != details::CNNNetworkIterator()) {
        try {
            // If the node factory can create the layer without throwing, the layer is supported.
            std::unique_ptr<MKLDNNNode>(MKLDNNNode::CreateNode(*i, eng, extensionManager));
            res.supportedLayers.insert((*i)->name);
        } catch (InferenceEngine::details::InferenceEngineException&) {
            // Unsupported layer: skip it and keep iterating.
        }
        ++i;
    }
}
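// Sketch of how a caller (for example, the HETERO plugin picking per-layer
// fallback devices) typically consumes the result; "conv1" is a placeholder:
//
//     QueryNetworkResult res;
//     engine.QueryNetwork(network, {}, res);
//     bool runsOnCpu = res.supportedLayers.count("conv1") > 0;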

INFERENCE_PLUGIN_API(StatusCode) CreatePluginEngine(IInferencePlugin*& plugin, ResponseDesc *resp) noexcept {
    try {
        plugin = make_ie_compatible_plugin(
                {{1, 6},
                 CI_BUILD_NUMBER,
                 "MKLDNNPlugin"}, std::make_shared<Engine>());
        return OK;
    }
    catch (std::exception &ex) {
        return DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
    }
    catch (...) {
        // The function is noexcept: report non-std exceptions too instead of terminating.
        return DescriptionBuffer(GENERAL_ERROR, resp) << "Unknown exception thrown while creating the plugin";
    }
}
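// Note: this exported factory is what the Inference Engine core resolves from
// the MKLDNNPlugin shared library when the "CPU" device is requested; {1, 6}
// is the plugin API version reported to the loader. Called directly it would
// look roughly like:
//
//     IInferencePlugin *plugin = nullptr;
//     ResponseDesc resp;
//     StatusCode sts = CreatePluginEngine(plugin, &resp);
//     if (sts != OK) { /* resp.msg holds the error description */ }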