Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / extensions / fake_layer.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <extension/ext_list.hpp>
6 #include <extension/ext_base.cpp>
7
8 #include <string>
9 #include <map>
10 #include <memory>
11 #include <algorithm>
12
13 using namespace InferenceEngine;
14 using namespace Extensions;
15
16 struct TestExtensionsHolder {
17     std::map<std::string, Cpu::ext_factory> list;
18     std::map<std::string, IShapeInferImpl::Ptr> si_list;
19 };
20
21
22 class FakeExtensions : public IExtension {
23  public:
24
25     void SetLogCallback(InferenceEngine::IErrorListener &listener) noexcept override {};
26
27     void Unload() noexcept override {};
28
29     void Release() noexcept override {
30         delete this;
31     };
32
33     static std::shared_ptr<TestExtensionsHolder> GetExtensionsHolder() {
34         static std::shared_ptr<TestExtensionsHolder> localHolder;
35         if (localHolder == nullptr) {
36             localHolder = std::shared_ptr<TestExtensionsHolder>(new TestExtensionsHolder());
37         }
38         return localHolder;
39     }
40
41     static void AddExt(std::string name, Cpu::ext_factory factory) {
42         GetExtensionsHolder()->list[name] = factory;
43     }
44
45     void GetVersion(const Version *&versionInfo) const noexcept override {
46         static Version ExtensionDescription = {
47             {1, 6},    // extension API version
48             "1.6",
49             "ie-cpu-ext"  // extension description message
50         };
51
52         versionInfo = &ExtensionDescription;
53     }
54
55     StatusCode getPrimitiveTypes(char **&types, unsigned int &size, ResponseDesc *resp) noexcept override {
56         collectTypes(types, size, GetExtensionsHolder()->list);
57         return OK;
58     };
59     StatusCode getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept override {
60         auto &factories = GetExtensionsHolder()->list;
61         if (factories.find(cnnLayer->type) == factories.end()) {
62             std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
63             errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
64             return NOT_FOUND;
65         }
66         factory = factories[cnnLayer->type](cnnLayer);
67         return OK;
68     }
69     StatusCode getShapeInferTypes(char **&types, unsigned int &size, ResponseDesc *resp) noexcept override {
70         collectTypes(types, size, GetExtensionsHolder()->si_list);
71         return OK;
72     };
73
74     StatusCode getShapeInferImpl(IShapeInferImpl::Ptr &impl, const char *type, ResponseDesc *resp) noexcept override {
75         auto &factories = GetExtensionsHolder()->si_list;
76         if (factories.find(type) == factories.end()) {
77             std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
78             if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
79             return NOT_FOUND;
80         }
81         impl = factories[type];
82         return OK;
83     }
84
85     template<class T>
86     void collectTypes(char **&types, unsigned int &size, const std::map<std::string, T> &factories) {
87         types = new char *[factories.size()];
88         unsigned count = 0;
89         for (auto it = factories.begin(); it != factories.end(); it++, count++) {
90             types[count] = new char[it->first.size() + 1];
91             std::copy(it->first.begin(), it->first.end(), types[count]);
92             types[count][it->first.size()] = '\0';
93         }
94         size = count;
95     }
96 };
97
98  class FakeLayerPLNImpl: public Cpu::ExtLayerBase {
99 public:
100     explicit FakeLayerPLNImpl(const CNNLayer* layer) {
101         try {
102             addConfig(layer, {{ConfLayout::PLN, false, 0}}, {{ConfLayout::PLN, false, 0}});
103         } catch (InferenceEngine::details::InferenceEngineException &ex) {
104             errorMsg = ex.what();
105         }
106     }
107
108     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
109                        ResponseDesc *resp) noexcept override {
110         return OK;
111     }
112 };
113
114 class FakeLayerBLKImpl: public Cpu::ExtLayerBase {
115 public:
116     explicit FakeLayerBLKImpl(const CNNLayer* layer) {
117         try {
118 #if defined(HAVE_AVX512F)
119             auto blk_layout = ConfLayout::BLK16;
120 #else
121             auto blk_layout = ConfLayout::BLK8;
122 #endif
123             addConfig(layer, {{blk_layout, false, 0}}, {{blk_layout, false, 0}});
124         } catch (InferenceEngine::details::InferenceEngineException &ex) {
125             errorMsg = ex.what();
126         }
127     }
128
129     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
130                        ResponseDesc *resp) noexcept override {
131         return OK;
132     }
133 };
134
135 template<typename Ext>
136 class FakeRegisterBase {
137  public:
138     explicit FakeRegisterBase(const std::string& type) {
139         FakeExtensions::AddExt(type,
140                               [](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
141                                   return new Ext(layer);
142                               });
143     }
144 };
145
146 #define REG_FAKE_FACTORY_FOR(__prim, __type) \
147 static FakeRegisterBase<__prim> __reg__##__type(#__type)
148
149 REG_FAKE_FACTORY_FOR(Cpu::ImplFactory<FakeLayerPLNImpl>, FakeLayerPLN);
150 REG_FAKE_FACTORY_FOR(Cpu::ImplFactory<FakeLayerBLKImpl>, FakeLayerBLK);
151
152
153 InferenceEngine::IExtensionPtr make_FakeExtensions() {
154     return InferenceEngine::IExtensionPtr(new FakeExtensions());
155 }