456943786c5017d5218c57d887d5f9e5a661bdf1
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_list.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "ext_list.hpp"
7
8 #include <string>
9 #include <map>
10 #include <memory>
11 #include <algorithm>
12
13 namespace InferenceEngine {
14 namespace Extensions {
15 namespace Cpu {
16
17 std::shared_ptr<ExtensionsHolder> CpuExtensions::GetExtensionsHolder() {
18     static std::shared_ptr<ExtensionsHolder> localHolder;
19     if (localHolder == nullptr) {
20         localHolder = std::shared_ptr<ExtensionsHolder>(new ExtensionsHolder());
21     }
22     return localHolder;
23 }
24
25 void CpuExtensions::AddExt(std::string name, ext_factory factory) {
26     GetExtensionsHolder()->list[name] = factory;
27 }
28
29 void CpuExtensions::AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl) {
30     GetExtensionsHolder()->si_list[name] = impl;
31 }
32
33 void CpuExtensions::GetVersion(const Version*& versionInfo) const noexcept {
34     static Version ExtensionDescription = {
35             { 1, 0 },    // extension API version
36             "1.0",
37             "ie-cpu-ext"  // extension description message
38     };
39
40     versionInfo = &ExtensionDescription;
41 }
42
43 StatusCode CpuExtensions::getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
44     collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->list);
45     return OK;
46 };
47 StatusCode CpuExtensions::getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept {
48     auto& factories = CpuExtensions::GetExtensionsHolder()->list;
49     if (factories.find(cnnLayer->type) == factories.end()) {
50         std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
51         errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
52         return NOT_FOUND;
53     }
54     factory = factories[cnnLayer->type](cnnLayer);
55     return OK;
56 }
57 StatusCode CpuExtensions::getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
58     collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->si_list);
59     return OK;
60 };
61
62 StatusCode CpuExtensions::getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept {
63     auto& factories = CpuExtensions::GetExtensionsHolder()->si_list;
64     if (factories.find(type) == factories.end()) {
65         std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
66         if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
67         return NOT_FOUND;
68     }
69     impl = factories[type];
70     return OK;
71 }
72
73 template<class T>
74 void CpuExtensions::collectTypes(char**& types, unsigned int& size, const std::map<std::string, T>& factories) {
75     types = new char *[factories.size()];
76     unsigned count = 0;
77     for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
78         types[count] = new char[it->first.size() + 1];
79         std::copy(it->first.begin(), it->first.end(), types[count]);
80         types[count][it->first.size() ] = '\0';
81     }
82     size = count;
83 }
84
85
86 // Exported function
87 INFERENCE_EXTENSION_API(StatusCode) CreateExtension(IExtension*& ext, ResponseDesc* resp) noexcept {
88     try {
89         ext = new CpuExtensions();
90         return OK;
91     } catch (std::exception& ex) {
92         if (resp) {
93             std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
94             err.copy(resp->msg, 255);
95         }
96         return GENERAL_ERROR;
97     }
98 }
99
100 }  // namespace Cpu
101 }  // namespace Extensions
102 }  // namespace InferenceEngine
103