1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <extension/ext_list.hpp>
6 #include <extension/ext_base.cpp>
13 using namespace InferenceEngine;
14 using namespace Extensions;
16 struct TestExtensionsHolder {
17 std::map<std::string, Cpu::ext_factory> list;
18 std::map<std::string, IShapeInferImpl::Ptr> si_list;
22 class FakeExtensions : public IExtension {
25 void SetLogCallback(InferenceEngine::IErrorListener &listener) noexcept override {};
27 void Unload() noexcept override {};
29 void Release() noexcept override {
33 static std::shared_ptr<TestExtensionsHolder> GetExtensionsHolder() {
34 static std::shared_ptr<TestExtensionsHolder> localHolder;
35 if (localHolder == nullptr) {
36 localHolder = std::shared_ptr<TestExtensionsHolder>(new TestExtensionsHolder());
41 static void AddExt(std::string name, Cpu::ext_factory factory) {
42 GetExtensionsHolder()->list[name] = factory;
45 void GetVersion(const Version *&versionInfo) const noexcept override {
46 static Version ExtensionDescription = {
47 {1, 6}, // extension API version
49 "ie-cpu-ext" // extension description message
52 versionInfo = &ExtensionDescription;
55 StatusCode getPrimitiveTypes(char **&types, unsigned int &size, ResponseDesc *resp) noexcept override {
56 collectTypes(types, size, GetExtensionsHolder()->list);
59 StatusCode getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept override {
60 auto &factories = GetExtensionsHolder()->list;
61 if (factories.find(cnnLayer->type) == factories.end()) {
62 std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
63 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
66 factory = factories[cnnLayer->type](cnnLayer);
69 StatusCode getShapeInferTypes(char **&types, unsigned int &size, ResponseDesc *resp) noexcept override {
70 collectTypes(types, size, GetExtensionsHolder()->si_list);
74 StatusCode getShapeInferImpl(IShapeInferImpl::Ptr &impl, const char *type, ResponseDesc *resp) noexcept override {
75 auto &factories = GetExtensionsHolder()->si_list;
76 if (factories.find(type) == factories.end()) {
77 std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
78 if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
81 impl = factories[type];
86 void collectTypes(char **&types, unsigned int &size, const std::map<std::string, T> &factories) {
87 types = new char *[factories.size()];
89 for (auto it = factories.begin(); it != factories.end(); it++, count++) {
90 types[count] = new char[it->first.size() + 1];
91 std::copy(it->first.begin(), it->first.end(), types[count]);
92 types[count][it->first.size()] = '\0';
98 class FakeLayerPLNImpl: public Cpu::ExtLayerBase {
100 explicit FakeLayerPLNImpl(const CNNLayer* layer) {
102 addConfig(layer, {{ConfLayout::PLN, false, 0}}, {{ConfLayout::PLN, false, 0}});
103 } catch (InferenceEngine::details::InferenceEngineException &ex) {
104 errorMsg = ex.what();
108 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
109 ResponseDesc *resp) noexcept override {
114 class FakeLayerBLKImpl: public Cpu::ExtLayerBase {
116 explicit FakeLayerBLKImpl(const CNNLayer* layer) {
118 #if defined(HAVE_AVX512F)
119 auto blk_layout = ConfLayout::BLK16;
121 auto blk_layout = ConfLayout::BLK8;
123 addConfig(layer, {{blk_layout, false, 0}}, {{blk_layout, false, 0}});
124 } catch (InferenceEngine::details::InferenceEngineException &ex) {
125 errorMsg = ex.what();
129 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
130 ResponseDesc *resp) noexcept override {
135 template<typename Ext>
136 class FakeRegisterBase {
138 explicit FakeRegisterBase(const std::string& type) {
139 FakeExtensions::AddExt(type,
140 [](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
141 return new Ext(layer);
146 #define REG_FAKE_FACTORY_FOR(__prim, __type) \
147 static FakeRegisterBase<__prim> __reg__##__type(#__type)
149 REG_FAKE_FACTORY_FOR(Cpu::ImplFactory<FakeLayerPLNImpl>, FakeLayerPLN);
150 REG_FAKE_FACTORY_FOR(Cpu::ImplFactory<FakeLayerBLKImpl>, FakeLayerBLK);
153 InferenceEngine::IExtensionPtr make_FakeExtensions() {
154 return InferenceEngine::IExtensionPtr(new FakeExtensions());