inference-engine/samples/hello_shape_infer_ssd/shape_infer_extension.hpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <inference_engine.hpp>

// Layer type served by this extension.
#define CUSTOM_RELU_TYPE std::string("CustomReLU")

// Reference CPU kernel for the CustomReLU layer. Instances are created by
// CustomReLUFactory below.
class CustomReLUImpl : public InferenceEngine::ILayerExecImpl {
public:
    explicit CustomReLUImpl(const InferenceEngine::CNNLayer& layer) : _layer(layer) {}

    InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
                                                           InferenceEngine::ResponseDesc* resp) noexcept override {
        // Expose a single configuration that mirrors the tensor descriptors
        // of the layer's first input and first output.
        InferenceEngine::DataConfig inDataConfig;
        InferenceEngine::DataConfig outDataConfig;
        auto firstInput = *_layer.insData.begin();
        auto firstOutput = *_layer.outData.begin();
        inDataConfig.desc = firstInput.lock()->getTensorDesc();
        outDataConfig.desc = firstOutput->getTensorDesc();
        InferenceEngine::LayerConfig layerConfig;
        layerConfig.inConfs = {inDataConfig};
        layerConfig.outConfs = {outDataConfig};
        conf.push_back(layerConfig);
        return InferenceEngine::StatusCode::OK;
    }

    InferenceEngine::StatusCode
    init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc* resp) noexcept override {
        return InferenceEngine::StatusCode::OK;
    }

    InferenceEngine::StatusCode
    execute(std::vector<InferenceEngine::Blob::Ptr>& inputs, std::vector<InferenceEngine::Blob::Ptr>& outputs,
            InferenceEngine::ResponseDesc* resp) noexcept override {
        static bool wasCalled = false;
        if (!wasCalled) {
            std::cout << "Running " + CUSTOM_RELU_TYPE + " kernel for the first time (next messages won't be printed)"
                      << std::endl;
            wasCalled = true;
        }
        // Element-wise ReLU in FP32: out[j] = max(in[j], 0).
        for (size_t i = 0; i < inputs.size(); i++) {
            auto inputBlob = inputs[i];
            auto outputBlob = outputs[i];
            auto inputData = inputBlob->buffer()
                    .as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP32>::value_type*>();
            auto outputData = outputBlob->buffer()
                    .as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP32>::value_type*>();
            for (size_t j = 0; j < inputBlob->size(); j++) {
                outputData[j] = inputData[j] < 0 ? 0 : inputData[j];
            }
        }
        return InferenceEngine::StatusCode::OK;
    }

private:
    const InferenceEngine::CNNLayer _layer;
};

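// --- Illustrative check (not part of the original sample) -------------------
// A minimal sketch of running the kernel above directly on a small FP32 blob.
// The CNNLayer here is a hypothetical stub: execute() does not inspect it.
#ifdef SHAPE_INFER_EXTENSION_USAGE_SKETCH
inline void customReluExecuteExample() {
    InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, {4},
                                     InferenceEngine::Layout::C);
    auto in = InferenceEngine::make_shared_blob<float>(desc);
    auto out = InferenceEngine::make_shared_blob<float>(desc);
    in->allocate();
    out->allocate();
    float* data = in->buffer().as<float*>();
    data[0] = -1.f; data[1] = 2.f; data[2] = -3.f; data[3] = 4.f;

    InferenceEngine::CNNLayer layer({"relu", CUSTOM_RELU_TYPE, InferenceEngine::Precision::FP32});
    CustomReLUImpl impl(layer);
    std::vector<InferenceEngine::Blob::Ptr> inputs = {in}, outputs = {out};
    impl.execute(inputs, outputs, nullptr);
    // outputs[0] now holds {0, 2, 0, 4}: negatives are clamped to zero.
}
#endif
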
// Factory that produces CustomReLUImpl instances for a given layer.
class CustomReLUFactory : public InferenceEngine::ILayerImplFactory {
public:
    explicit CustomReLUFactory(const InferenceEngine::CNNLayer* layer) : _layer(*layer) {}

    InferenceEngine::StatusCode
    getImplementations(std::vector<InferenceEngine::ILayerImpl::Ptr>& impls,
                       InferenceEngine::ResponseDesc* resp) noexcept override {
        impls.push_back(std::make_shared<CustomReLUImpl>(_layer));
        return InferenceEngine::StatusCode::OK;
    }

private:
    InferenceEngine::CNNLayer _layer;
};

// Shape inference for CustomReLU: ReLU is element-wise, so every output keeps
// the dimensions of the corresponding input.
class CustomReLUResizeImpl : public InferenceEngine::IShapeInferImpl {
public:
    InferenceEngine::StatusCode inferShapes(const std::vector<InferenceEngine::Blob::CPtr>& inBlobs,
                                            const std::map<std::string, std::string>& params,
                                            const std::map<std::string, InferenceEngine::Blob::Ptr>& blobs,
                                            std::vector<InferenceEngine::SizeVector>& outShapes,
                                            InferenceEngine::ResponseDesc* desc) noexcept override {
        static bool wasCalled = false;
        if (!wasCalled) {
            std::cout << "Running " + CUSTOM_RELU_TYPE +
                         " shape inference for the first time (next messages won't be printed)" << std::endl;
            wasCalled = true;
        }
        for (const auto& blob : inBlobs) {
            outShapes.push_back(blob->getTensorDesc().getDims());
        }
        return InferenceEngine::StatusCode::OK;
    }
};

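// --- Illustrative check (not part of the original sample) -------------------
// A minimal sketch of calling the shape inference above directly. The blob is
// only used for its TensorDesc, so it does not need to be allocated.
#ifdef SHAPE_INFER_EXTENSION_USAGE_SKETCH
inline void inferShapesExample() {
    InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32,
                                     {1, 3, 300, 300}, InferenceEngine::Layout::NCHW);
    std::vector<InferenceEngine::Blob::CPtr> inBlobs = {
            InferenceEngine::make_shared_blob<float>(desc)};
    std::vector<InferenceEngine::SizeVector> outShapes;
    CustomReLUResizeImpl impl;
    impl.inferShapes(inBlobs, {}, {}, outShapes, nullptr);
    // outShapes now holds {1, 3, 300, 300}: the input dims are propagated.
}
#endif
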
// Extension that registers both the execution kernel (via CustomReLUFactory)
// and the shape inference implementation for the CustomReLU layer type.
class InPlaceExtension : public InferenceEngine::IExtension {
public:
    InPlaceExtension() {
        _shapeInferImpl = std::make_shared<CustomReLUResizeImpl>();
    }

    InferenceEngine::StatusCode
    getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
        // Report a single supported layer type as a heap-allocated,
        // null-terminated C string.
        size = 1;
        types = new char*[size];
        std::string type = CUSTOM_RELU_TYPE;
        types[0] = new char[type.size() + 1];
        std::copy(type.begin(), type.end(), types[0]);
        types[0][type.size()] = '\0';
        return InferenceEngine::StatusCode::OK;
    }

    InferenceEngine::StatusCode
    getShapeInferTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
        // The same single type is supported for shape inference.
        return getPrimitiveTypes(types, size, resp);
    }

    InferenceEngine::StatusCode getShapeInferImpl(InferenceEngine::IShapeInferImpl::Ptr& impl, const char* type,
                                                  InferenceEngine::ResponseDesc* resp) noexcept override {
        if (CUSTOM_RELU_TYPE.compare(type) != 0) return InferenceEngine::StatusCode::NOT_IMPLEMENTED;
        impl = _shapeInferImpl;
        return InferenceEngine::StatusCode::OK;
    }

    void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override {}

    void SetLogCallback(InferenceEngine::IErrorListener& listener) noexcept override {}

    void Unload() noexcept override {}

    void Release() noexcept override {}

    InferenceEngine::StatusCode
    getFactoryFor(InferenceEngine::ILayerImplFactory*& factory, const InferenceEngine::CNNLayer* cnnLayer,
                  InferenceEngine::ResponseDesc* resp) noexcept override {
        if (cnnLayer->type != CUSTOM_RELU_TYPE)
            return InferenceEngine::StatusCode::NOT_IMPLEMENTED;
        factory = new CustomReLUFactory(cnnLayer);
        return InferenceEngine::StatusCode::OK;
    }

private:
    InferenceEngine::IShapeInferImpl::Ptr _shapeInferImpl;
};
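
// --- Illustrative usage (not part of the original sample) -------------------
// A minimal sketch, assuming the 2019 R1 Inference Engine API, of how this
// extension could be registered so that CNNNetwork::reshape() can resolve the
// CustomReLU layer. The function name, model paths, and the batch-size change
// are hypothetical; the real sample wires this up in its main.cpp.
#ifdef SHAPE_INFER_EXTENSION_USAGE_SKETCH
inline void registerInPlaceExtensionExample(const std::string& modelXml,
                                            const std::string& modelBin) {
    InferenceEngine::CNNNetReader reader;
    reader.ReadNetwork(modelXml);   // read the IR topology
    reader.ReadWeights(modelBin);   // read the IR weights
    InferenceEngine::CNNNetwork network = reader.getNetwork();

    // Register the custom shape inference (CustomReLUResizeImpl) with the
    // network so reshape() can infer output dimensions for CustomReLU.
    auto extension = std::make_shared<InPlaceExtension>();
    network.AddExtension(extension);

    // Double the batch dimension of every input (assuming an N-first layout)
    // and re-run shape inference over the whole network.
    auto inputShapes = network.getInputShapes();
    for (auto& shape : inputShapes) shape.second[0] *= 2;
    network.reshape(inputShapes);
}
#endif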