Removed shape infer extension (#917)
[platform/upstream/dldt.git] / inference-engine / src / mkldnn_plugin / mkldnn_extension_mngr.cpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vector>
6 #include <string>
7 #include <algorithm>
8
9 #include "mkldnn_extension_mngr.h"
10
11 using namespace MKLDNNPlugin;
12 using namespace InferenceEngine;
13
14 void MKLDNNExtensionManager::AddExtension(IExtensionPtr extension) {
15     _extensions.push_back(extension);
16 }
17
18 InferenceEngine::ILayerImpl::Ptr MKLDNNExtensionManager::CreateImplementation(const std::shared_ptr<ngraph::Node>& op) {
19     if (!op)
20         THROW_IE_EXCEPTION << "Cannot get nGraph operation!";
21     for (const auto& ext : _extensions) {
22         auto implTypes = ext->getImplTypes(op);
23         for (const auto& type : implTypes) {
24             if (type != "CPU")
25                 continue;
26             auto impl = ext->getImplementation(op, "CPU");
27             if (impl)
28                 return impl;
29         }
30     }
31     return nullptr;
32 }
33
34 IE_SUPPRESS_DEPRECATED_START
35
36 std::shared_ptr<InferenceEngine::ILayerImplFactory> MKLDNNExtensionManager::CreateExtensionFactory(
37         const InferenceEngine::CNNLayerPtr &layer) {
38     if (!layer)
39         THROW_IE_EXCEPTION << "Cannot get cnn layer!";
40     std::shared_ptr<ILayerImplFactory> factory;
41     for (auto& ext : _extensions) {
42         ResponseDesc responseDesc;
43         StatusCode rc;
44         ILayerImplFactory* factory_ptr = nullptr;
45         rc = ext->getFactoryFor(factory_ptr, layer.get(), &responseDesc);
46         if (rc != OK) {
47             factory = nullptr;
48             continue;
49         } else {
50             factory.reset(factory_ptr);
51         }
52         if (factory) {
53             break;
54         }
55     }
56     return factory;
57 }
58
59 IE_SUPPRESS_DEPRECATED_END