1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include "ext_list.hpp"
13 namespace InferenceEngine {
14 namespace Extensions {
17 std::shared_ptr<ExtensionsHolder> CpuExtensions::GetExtensionsHolder() {
18 static std::shared_ptr<ExtensionsHolder> localHolder;
19 if (localHolder == nullptr) {
20 localHolder = std::shared_ptr<ExtensionsHolder>(new ExtensionsHolder());
25 void CpuExtensions::AddExt(std::string name, ext_factory factory) {
26 GetExtensionsHolder()->list[name] = factory;
29 void CpuExtensions::AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl) {
30 GetExtensionsHolder()->si_list[name] = impl;
33 void CpuExtensions::GetVersion(const Version*& versionInfo) const noexcept {
34 static Version ExtensionDescription = {
35 { 1, 0 }, // extension API version
37 "ie-cpu-ext" // extension description message
40 versionInfo = &ExtensionDescription;
43 StatusCode CpuExtensions::getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
44 collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->list);
47 StatusCode CpuExtensions::getFactoryFor(ILayerImplFactory *&factory, const CNNLayer *cnnLayer, ResponseDesc *resp) noexcept {
48 auto& factories = CpuExtensions::GetExtensionsHolder()->list;
49 if (factories.find(cnnLayer->type) == factories.end()) {
50 std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
51 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
54 factory = factories[cnnLayer->type](cnnLayer);
57 StatusCode CpuExtensions::getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept {
58 collectTypes(types, size, CpuExtensions::GetExtensionsHolder()->si_list);
62 StatusCode CpuExtensions::getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept {
63 auto& factories = CpuExtensions::GetExtensionsHolder()->si_list;
64 if (factories.find(type) == factories.end()) {
65 std::string errorMsg = std::string("Shape Infer Implementation for ") + type + " wasn't found!";
66 if (resp) errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
69 impl = factories[type];
74 void CpuExtensions::collectTypes(char**& types, unsigned int& size, const std::map<std::string, T>& factories) {
75 types = new char *[factories.size()];
77 for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
78 types[count] = new char[it->first.size() + 1];
79 std::copy(it->first.begin(), it->first.end(), types[count]);
80 types[count][it->first.size() ] = '\0';
87 INFERENCE_EXTENSION_API(StatusCode) CreateExtension(IExtension*& ext, ResponseDesc* resp) noexcept {
89 ext = new CpuExtensions();
91 } catch (std::exception& ex) {
93 std::string err = ((std::string)"Couldn't create extension: ") + ex.what();
94 err.copy(resp->msg, 255);
101 } // namespace Extensions
102 } // namespace InferenceEngine