Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_list.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6
7 #include <string>
8 #include <map>
9 #include <memory>
10 #include <algorithm>
11
12 namespace InferenceEngine {
13 namespace Extensions {
14 namespace Cpu {
15
16 std::shared_ptr<ExtensionsHolder> CpuExtensions::GetExtensionsHolder() {
17     static std::shared_ptr<ExtensionsHolder> localHolder;
18     if (localHolder == nullptr) {
19         localHolder = std::shared_ptr<ExtensionsHolder>(new ExtensionsHolder());
20     }
21     return localHolder;
22 }
23
24 void CpuExtensions::AddExt(std::string name, ext_factory factory) {
25     GetExtensionsHolder()->list[name] = factory;
26 }
27
28 void CpuExtensions::AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl) {
29     GetExtensionsHolder()->si_list[name] = impl;
30 }
31
32 void CpuExtensions::GetVersion(const Version*& versionInfo) const noexcept {
33     static Version ExtensionDescription = {
34             { 1, 6 },    // extension API version
35             "1.6",
36             "ie-cpu-ext"  // extension description message
37     };
38
39     versionInfo = &ExtensionDescription;
40 }
41
42 StatusCode CpuExtensions::getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
43     collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->list);
44     return OK;
45 };
46 StatusCode CpuExtensions::getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept {
47     auto& factories = CpuExtensions::GetExtensionsHolder()->list;
48     if (factories.find(cnnLayer->type) == factories.end()) {
49         std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
50         errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
51         return NOT_FOUND;
52     }
53     factory = factories[cnnLayer->type](cnnLayer);
54     return OK;
55 }
56 StatusCode CpuExtensions::getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
57     collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->si_list);
58     return OK;
59 };
60
61 StatusCode CpuExtensions::getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept {
62     auto& factories = CpuExtensions::GetExtensionsHolder()->si_list;
63     if (factories.find(type) == factories.end()) {
64         std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
65         if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
66         return NOT_FOUND;
67     }
68     impl = factories[type];
69     return OK;
70 }
71
72 template<class T>
73 void CpuExtensions::collectTypes(char**& types, unsigned int& size, const std::map<std::string, T>& factories) {
74     types = new char *[factories.size()];
75     unsigned count = 0;
76     for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
77         types[count] = new char[it->first.size() + 1];
78         std::copy(it->first.begin(), it->first.end(), types[count]);
79         types[count][it->first.size() ] = '\0';
80     }
81     size = count;
82 }
83
84 }  // namespace Cpu
85 }  // namespace Extensions
86
87
88 // Exported function
89 INFERENCE_EXTENSION_API(StatusCode) CreateExtension(IExtension*& ext, ResponseDesc* resp) noexcept {
90     try {
91         ext = new Extensions::Cpu::CpuExtensions();
92         return OK;
93     } catch (std::exception& ex) {
94         if (resp) {
95             std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
96             err.copy(resp->msg, 255);
97         }
98         return GENERAL_ERROR;
99     }
100 }
101
102 // Exported function
103 INFERENCE_EXTENSION_API(StatusCode) CreateShapeInferExtension(IShapeInferExtension*& ext, ResponseDesc* resp) noexcept {
104     IExtension * pExt = nullptr;
105     StatusCode  result = CreateExtension(pExt, resp);
106     if (result == OK) {
107         ext = pExt;
108     }
109
110     return result;
111 }
112
113 }  // namespace InferenceEngine