1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
12 namespace InferenceEngine {
13 namespace Extensions {
16 std::shared_ptr<ExtensionsHolder> CpuExtensions::GetExtensionsHolder() {
17 static std::shared_ptr<ExtensionsHolder> localHolder;
18 if (localHolder == nullptr) {
19 localHolder = std::shared_ptr<ExtensionsHolder>(new ExtensionsHolder());
24 void CpuExtensions::AddExt(std::string name, ext_factory factory) {
25 GetExtensionsHolder()->list[name] = factory;
28 void CpuExtensions::AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl) {
29 GetExtensionsHolder()->si_list[name] = impl;
// Reports this extension library's version descriptor.
// versionInfo is pointed at a function-local static Version object, so the
// address stays valid after the call returns.
32 void CpuExtensions::GetVersion(const Version*& versionInfo) const noexcept {
33 static Version ExtensionDescription = {
34 { 1, 6 }, // extension API version
36 "ie-cpu-ext" // extension description message
// Publish the address of the static descriptor to the caller.
39 versionInfo = &ExtensionDescription;
// Enumerates the layer types that have a registered factory.
// Delegates to collectTypes(), handing it the holder's `list` map; `types` and
// `size` are passed through by reference to that helper. `resp` is not used on
// the lines shown here.
42 StatusCode CpuExtensions::getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
43 collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->list);
46 StatusCode CpuExtensions::getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept {
47 auto& factories = CpuExtensions::GetExtensionsHolder()->list;
48 if (factories.find(cnnLayer->type) == factories.end()) {
49 std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
50 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
53 factory = factories[cnnLayer->type](cnnLayer);
// Enumerates the layer types that have a registered shape inference
// implementation.
// Delegates to collectTypes(), handing it the holder's `si_list` map; `types`
// and `size` are passed through by reference to that helper. `resp` is not used
// on the lines shown here.
56 StatusCode CpuExtensions::getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
57 collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->si_list);
61 StatusCode CpuExtensions::getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept {
62 auto& factories = CpuExtensions::GetExtensionsHolder()->si_list;
63 if (factories.find(type) == factories.end()) {
64 std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
65 if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
68 impl = factories[type];
// Exports the keys of `factories` as an array of C strings.
// `types` receives a new[]-allocated array with one slot per map entry; each
// slot points at a separately new[]-allocated, null-terminated copy of a key.
// Nothing on the lines shown here frees those buffers.
// NOTE(review): presumably the caller releases them — confirm against callers.
73 void CpuExtensions::collectTypes(char**& types, unsigned int& size, const std::map<std::string, T>& factories) {
74 types = new char *[factories.size()];
// Walk the map in key order; `count` is the output slot and advances in step
// with the iterator.
76 for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
// One extra byte leaves room for the terminating '\0'.
77 types[count] = new char[it->first.size() + 1];
78 std::copy(it->first.begin(), it->first.end(), types[count]);
// std::copy does not write a terminator, so append it explicitly.
79 types[count][it->first.size() ] = '\0';
85 } // namespace Extensions
89 INFERENCE_EXTENSION_API(StatusCode) CreateExtension(IExtension*& ext, ResponseDesc* resp) noexcept {
91 ext = new Extensions::Cpu::CpuExtensions();
93 } catch (std::exception& ex) {
95 std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
96 err.copy(resp->msg, 255);
// Entry point (declared with INFERENCE_EXTENSION_API) for obtaining the
// extension as an IShapeInferExtension.
// It builds the object through CreateExtension(), capturing both the created
// IExtension pointer and the resulting StatusCode in locals.
103 INFERENCE_EXTENSION_API(StatusCode) CreateShapeInferExtension(IShapeInferExtension*& ext, ResponseDesc* resp) noexcept {
// Starts as nullptr so the pointer has a defined value whatever
// CreateExtension() does with it.
104 IExtension * pExt = nullptr;
105 StatusCode result = CreateExtension(pExt, resp);
113 } // namespace InferenceEngine