Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_list.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_iextension.h>
8
#include <string>
#include <map>
#include <memory>
#include <algorithm>
#include <functional>
13
14 namespace InferenceEngine {
15 namespace Extensions {
16 namespace Cpu {
17
18 using ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer*)>;
19
20 struct ExtensionsHolder {
21     std::map<std::string, ext_factory> list;
22     std::map<std::string, IShapeInferImpl::Ptr> si_list;
23 };
24
25 class INFERENCE_ENGINE_API_CLASS(CpuExtensions) : public IExtension {
26 public:
27     StatusCode getPrimitiveTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
28
29     StatusCode
30     getFactoryFor(ILayerImplFactory*& factory, const CNNLayer* cnnLayer, ResponseDesc* resp) noexcept override;
31
32     StatusCode getShapeInferTypes(char**& types, unsigned int& size, ResponseDesc* resp) noexcept override;
33
34     StatusCode getShapeInferImpl(IShapeInferImpl::Ptr& impl, const char* type, ResponseDesc* resp) noexcept override;
35
36     void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override;
37
38     void SetLogCallback(InferenceEngine::IErrorListener& /*listener*/) noexcept override {}
39
40     void Unload() noexcept override {}
41
42     void Release() noexcept override {
43         delete this;
44     }
45
46     static void AddExt(std::string name, ext_factory factory);
47
48     static void AddShapeInferImpl(std::string name, const IShapeInferImpl::Ptr& impl);
49
50     static std::shared_ptr<ExtensionsHolder> GetExtensionsHolder();
51
52 private:
53     template<class T>
54     void collectTypes(char**& types, unsigned int& size, const std::map<std::string, T> &factories);
55 };
56
57 template<typename Ext>
58 class ExtRegisterBase {
59 public:
60     explicit ExtRegisterBase(const std::string& type) {
61         CpuExtensions::AddExt(type,
62                               [](const CNNLayer* layer) -> InferenceEngine::ILayerImplFactory* {
63                                   return new Ext(layer);
64                               });
65     }
66 };
67
// Defines a file-local static ExtRegisterBase<prim_> named __reg__<type_>,
// constructed with the stringified type_, so the registration runs during
// static initialization of the translation unit that uses the macro.
// NOTE(review): the generated identifier contains a double underscore, which
// the C++ standard reserves for the implementation; the name is kept as is.
#define REG_FACTORY_FOR(prim_, type_) static ExtRegisterBase<prim_> __reg__##type_(#type_)
70
71 template<typename Impl>
72 class ShapeInferImplRegister {
73 public:
74     explicit ShapeInferImplRegister(const std::string& type) {
75         CpuExtensions::AddShapeInferImpl(type, std::make_shared<Impl>());
76     }
77 };
78
// Defines a file-local static ShapeInferImplRegister<impl_> named
// __reg__si__<type_>, constructed with the stringified type_, so the
// registration runs during static initialization of the translation unit
// that uses the macro.
// NOTE(review): the generated identifier contains a double underscore, which
// the C++ standard reserves for the implementation; the name is kept as is.
#define REG_SHAPE_INFER_FOR_TYPE(impl_, type_) static ShapeInferImplRegister<impl_> __reg__si__##type_(#type_)
81
82 }  // namespace Cpu
83 }  // namespace Extensions
84 }  // namespace InferenceEngine