1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
11 #include <inference_engine.hpp>
// Name of the custom layer type this extension handles. Note: the macro
// constructs a fresh std::string at every use site.
#define CUSTOM_RELU_TYPE std::string("CustomReLU")
// Execution kernel for the custom "CustomReLU" layer, implemented through the
// Inference Engine extensibility interface ILayerExecImpl.
class CustomReLUImpl : public InferenceEngine::ILayerExecImpl {
    // Stores a copy of the layer description; configs are derived from it later.
    explicit CustomReLUImpl(const InferenceEngine::CNNLayer& layer) : _layer(layer) {}
19 InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf,
20 InferenceEngine::ResponseDesc* resp) noexcept override {
21 InferenceEngine::DataConfig inDataConfig;
22 InferenceEngine::DataConfig outDataConfig;
23 auto firstInput = *_layer.insData.begin();
24 auto firstOutput = *_layer.outData.begin();
25 inDataConfig.desc = firstInput.lock()->getTensorDesc();
26 outDataConfig.desc = firstOutput->getTensorDesc();
27 InferenceEngine::LayerConfig layerConfig;
28 layerConfig.inConfs = {inDataConfig};
29 layerConfig.outConfs = {outDataConfig};
30 conf.push_back(layerConfig);
31 return InferenceEngine::StatusCode::OK;
    // Nothing to initialize for this kernel; accept whatever config is offered.
    InferenceEngine::StatusCode
    init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc* resp) noexcept override {
        return InferenceEngine::StatusCode::OK;
    // Apply ReLU (x < 0 ? 0 : x) elementwise, mapping each input blob onto the
    // output blob with the same index.
    InferenceEngine::StatusCode
    execute(std::vector<InferenceEngine::Blob::Ptr>& inputs, std::vector<InferenceEngine::Blob::Ptr>& outputs,
            InferenceEngine::ResponseDesc* resp) noexcept override {
        // One-shot diagnostic flag so repeated inferences do not spam stdout.
        static bool wasCalled = false;
        std::cout << "Running " + CUSTOM_RELU_TYPE + " kernel for the first time (next messages won't be printed)"
        for (size_t i = 0; i < inputs.size(); i++) {
            auto inputBlob = inputs[i];
            auto outputBlob = outputs[i];
            // NOTE(review): buffers are reinterpreted as FP32 without checking
            // blob precision — assumes the IR declares FP32 tensors; confirm upstream.
            auto inputData = inputBlob->buffer().as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP32>::value_type*>();
            auto outputData = outputBlob->buffer().as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP32>::value_type*>();
            for (size_t j = 0; j < inputBlob->size(); j++) {
                outputData[j] = inputData[j] < 0 ? 0 : inputData[j];
        return InferenceEngine::StatusCode::OK;
    // Layer description captured at construction (copied, not referenced).
    const InferenceEngine::CNNLayer _layer;
// Factory handed to the Inference Engine core: produces CustomReLUImpl
// kernels for each occurrence of the custom layer.
class CustomReLUFactory : public InferenceEngine::ILayerImplFactory {
    // NOTE(review): dereferences `layer` without a null check — callers must
    // guarantee a valid pointer (getFactoryFor does, by reading cnnLayer->type first).
    explicit CustomReLUFactory(const InferenceEngine::CNNLayer* layer) : _layer(*layer) {}
68 InferenceEngine::StatusCode
69 getImplementations(std::vector<InferenceEngine::ILayerImpl::Ptr>& impls,
70 InferenceEngine::ResponseDesc* resp) noexcept override {
71 impls.push_back(std::make_shared<CustomReLUImpl>(_layer));
72 return InferenceEngine::StatusCode::OK;
    // Copy of the layer description passed to every created implementation.
    InferenceEngine::CNNLayer _layer;
// Shape-inference implementation for the custom layer: each output shape is
// identical to the corresponding input shape (the operation is elementwise).
class CustomReLUResizeImpl : public InferenceEngine::IShapeInferImpl {
    InferenceEngine::StatusCode inferShapes(const std::vector<InferenceEngine::Blob::CPtr>& inBlobs,
                                            const std::map<std::string, std::string>& params,
                                            const std::map<std::string, InferenceEngine::Blob::Ptr>& blobs,
                                            std::vector<InferenceEngine::SizeVector>& outShapes,
                                            InferenceEngine::ResponseDesc* desc) noexcept override {
        // One-shot diagnostic flag so repeated reshape calls do not spam stdout.
        static bool wasCalled = false;
        std::cout << "Running " + CUSTOM_RELU_TYPE +
            " shape inference for the first time (next messages won't be printed)" << std::endl;
        // Propagate each input's dims unchanged to the corresponding output.
        for (const auto& blob : inBlobs) {
            outShapes.push_back(blob->getTensorDesc().getDims());
        return InferenceEngine::StatusCode::OK;
// Extension entry point: exposes the custom kernel factory and the shape
// inference implementation to the Inference Engine core.
class InPlaceExtension : public InferenceEngine::IExtension {
        // Single shared shape-inference object, reused for every query.
        _shapeInferImpl = std::make_shared<CustomReLUResizeImpl>();
105 InferenceEngine::StatusCode
106 getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
108 types = new char* [size];
109 std::string type = CUSTOM_RELU_TYPE;
110 types[0] = new char[type.size() + 1];
111 std::copy(type.begin(), type.end(), types[0]);
112 types[0][type.size()] = 0;
113 return InferenceEngine::OK;
    // The shape-inference type names coincide with the kernel type names,
    // so the query is simply delegated.
    InferenceEngine::StatusCode
    getShapeInferTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
        return getPrimitiveTypes(types, size, resp);
121 InferenceEngine::StatusCode getShapeInferImpl(InferenceEngine::IShapeInferImpl::Ptr& impl, const char* type,
122 InferenceEngine::ResponseDesc* resp) noexcept override {
123 if (CUSTOM_RELU_TYPE.compare(type) != 0) return InferenceEngine::StatusCode::NOT_IMPLEMENTED;
124 impl = _shapeInferImpl;
125 return InferenceEngine::StatusCode::OK;
128 void GetVersion(const InferenceEngine::Version*& versionInfo) const noexcept override {};
130 void SetLogCallback(InferenceEngine::IErrorListener& listener) noexcept override {};
132 void Unload() noexcept override {};
134 void Release() noexcept override {}
136 InferenceEngine::StatusCode
137 getFactoryFor(InferenceEngine::ILayerImplFactory*& factory, const InferenceEngine::CNNLayer* cnnLayer,
138 InferenceEngine::ResponseDesc* resp) noexcept override {
139 if (cnnLayer->type != CUSTOM_RELU_TYPE)
140 return InferenceEngine::StatusCode::NOT_IMPLEMENTED;
141 factory = new CustomReLUFactory(cnnLayer);
142 return InferenceEngine::StatusCode::OK;
    // Shared shape-inference implementation created in the constructor.
    InferenceEngine::IShapeInferImpl::Ptr _shapeInferImpl;