1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <ie_iextension.h>
14 namespace InferenceEngine {
15 namespace Extensions {
18 using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer*)>;
// Registry of CPU extension entry points, keyed by layer type name (std::map, so
// iteration is in key order).
// NOTE(review): presumably populated by CpuExtensions::AddExt / AddShapeInferImpl;
// their definitions are not visible here — confirm.
20 struct ExtensionsHolder {
// Layer type name -> callback that creates an ILayerImplFactory for a layer.
21 std::map<std::string, ext_factory> list;
// Layer type name -> shape-inference implementation (held by IShapeInferImpl::Ptr).
22 std::map<std::string, IShapeInferImpl::Ptr> si_list;
// CPU extension library object implementing the IExtension interface.
// Registration is exposed through the static AddExt / AddShapeInferImpl members, and the
// registered entries live in the ExtensionsHolder returned by GetExtensionsHolder().
// NOTE(review): that the lookup methods below read from that holder is inferred from the
// declarations; their definitions are not visible in this view — confirm in the .cpp.
25 class INFERENCE_ENGINE_API_CLASS(CpuExtensions) : public IExtension {
// Reports, via `types`/`size`, the layer types for which a factory is available.
// NOTE(review): presumably implemented with collectTypes over ExtensionsHolder::list.
27 StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
// getFactoryFor: returns in `factory` an implementation factory for `cnnLayer`;
// failures are reported through the returned status and `resp`.
30 getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer, ResponseDesc* resp) noexcept override;
// Reports, via `types`/`size`, the layer types that have a shape-inference implementation.
// NOTE(review): presumably implemented with collectTypes over ExtensionsHolder::si_list.
32 StatusCode getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
// Returns in `impl` the shape-inference implementation registered for layer type `type`.
34 StatusCode getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept override;
// Returns this extension library's version information through `versionInfo`.
36 void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
// Intentionally empty: the supplied error listener is ignored.
38 void SetLogCallback(InferenceEngine::IErrorListener& /*listener*/) noexcept override {}
// Intentionally empty: nothing to do on unload.
40 void Unload() noexcept override {}
// Override of the base interface's Release(); its body is not shown in this view.
42 void Release() noexcept override {
// Registers `factory` as the implementation-factory creator for layer type `name`.
// Called from ExtRegisterBase (i.e. from static initializers expanded by REG_FACTORY_FOR).
46 static void AddExt(std::string name, ext_factory factory);
// Registers `impl` as the shape-inference implementation for layer type `name`.
// Called from ShapeInferImplRegister (expanded by REG_SHAPE_INFER_FOR_TYPE).
48 static void AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl);
// Returns the shared registry of extension entry points.
// NOTE(review): because registrars call AddExt/AddShapeInferImpl during static
// initialization of other translation units, this presumably creates the holder lazily
// (e.g. a function-local static) to avoid static-init-order problems — confirm.
50 static std::shared_ptr<ExtensionsHolder> GetExtensionsHolder();
// Helper (templated on the map's value type T) that reports the keys of `factories`
// through the `types`/`size` out-parameters.
// NOTE(review): the allocation/ownership of the `types` array is not visible here — confirm.
54 void collectTypes(char**& types, unsigned int& size, const std::map<std::string, T> &factories);
// Registrar helper: constructing an ExtRegisterBase<Ext> registers, under layer type
// `type`, a callback that returns `new Ext(layer)` for the requested CNNLayer.
// Ext must be constructible from `const CNNLayer*` and convertible to ILayerImplFactory*.
// Intended to be instantiated as a static object through REG_FACTORY_FOR.
// NOTE(review): each callback invocation returns a raw heap allocation; ownership passes
// to whoever calls the ext_factory (presumably released via the factory interface) — confirm.
57 template<typename Ext>
58 class ExtRegisterBase {
60 explicit ExtRegisterBase(const std::string& type) {
61 CpuExtensions::AddExt(type,
// Capture-less lambda: it does not reference the registrar, so it stays valid after
// this constructor returns.
62 [](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
63 return new Ext(layer);
// Registers a CPU layer-implementation factory type for layer type `layer_type`.
//
// Expands to a file-local (static) ExtRegisterBase<impl_factory> object whose
// constructor calls CpuExtensions::AddExt(#layer_type, ...) when the object is
// initialized as part of the including translation unit's static-storage setup.
//
// `layer_type` must be a bare identifier: it is stringized to form the registered
// type name and token-pasted into the registrar variable's name, so a given
// `layer_type` can be used with this macro at most once per translation unit.
// Intended for namespace scope, terminated with ';', e.g.
//     REG_FACTORY_FOR(MyLayerFactory, MyLayer);
//
// Parameter and generated variable names deliberately avoid "__": identifiers
// containing a double underscore are reserved for the implementation in C++.
#define REG_FACTORY_FOR(impl_factory, layer_type) \
    static ExtRegisterBase<impl_factory> ie_ext_factory_reg_##layer_type(#layer_type)
// Registrar helper: constructing a ShapeInferImplRegister<Impl> registers one
// default-constructed Impl (held by shared_ptr) as the shape-inference implementation
// for layer type `type`. Impl must be default-constructible and its shared_ptr must
// convert to IShapeInferImpl::Ptr. Intended to be instantiated through
// REG_SHAPE_INFER_FOR_TYPE.
71 template<typename Impl>
72 class ShapeInferImplRegister {
74 explicit ShapeInferImplRegister(const std::string& type) {
// A single Impl instance is created per registrar; presumably that same instance is
// handed out for every later lookup of `type` — confirm against getShapeInferImpl.
75 CpuExtensions::AddShapeInferImpl(type, std::make_shared<Impl>());
// Registers shape-inference implementation `impl` for layer type `layer_type`.
//
// Expands to a file-local (static) ShapeInferImplRegister<impl> object whose
// constructor calls CpuExtensions::AddShapeInferImpl(#layer_type, make_shared<impl>())
// when the object is initialized as part of the including translation unit's
// static-storage setup.
//
// `layer_type` must be a bare identifier: it is stringized to form the registered
// type name and token-pasted into the registrar variable's name, so a given
// `layer_type` can be used with this macro at most once per translation unit.
// Intended for namespace scope, terminated with ';', e.g.
//     REG_SHAPE_INFER_FOR_TYPE(MyShapeInfer, MyLayer);
//
// Parameter and generated variable names deliberately avoid "__": identifiers
// containing a double underscore are reserved for the implementation in C++.
#define REG_SHAPE_INFER_FOR_TYPE(impl, layer_type) \
    static ShapeInferImplRegister<impl> ie_ext_shape_infer_reg_##layer_type(#layer_type)
83 } // namespace Extensions
84 } // namespace InferenceEngine