Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_priorgridgenerator_onnx.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7 #include <algorithm>
8 #include <cassert>
9 #include <vector>
10
11 namespace InferenceEngine {
12 namespace Extensions {
13 namespace Cpu {
14
// Positions of the layer's inputs, used to index insData / the input blob vector.
constexpr int INPUT_PRIORS = 0;
constexpr int INPUT_FEATUREMAP = 1;
constexpr int INPUT_IMAGE = 2;

// Position of the generated prior grid in the output blob vector.
constexpr int OUTPUT_ROIS = 0;
20
21 class ExperimentalDetectronPriorGridGeneratorImpl: public ExtLayerBase {
22 private:
23     // Inputs:
24     //      priors, shape [n, 4]
25     //      [feature_map], shape [b, c, h, w]
26     //      [im_data], shape [b, 3, im_h, im_w]
27     // Outputs:
28     //      priors_grid, shape [m, 4]
29
30 public:
31     explicit ExperimentalDetectronPriorGridGeneratorImpl(const CNNLayer* layer) {
32         try {
33             if (layer->insData.size() > 3 || layer->outData.empty())
34                 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
35
36             if (layer->insData[INPUT_PRIORS].lock()->dims.size() != 2 ||
37                     (layer->insData.size() > INPUT_FEATUREMAP &&
38                      layer->insData[INPUT_FEATUREMAP].lock()->dims.size() != 4) ||
39                     (layer->insData.size() > INPUT_IMAGE &&
40                      layer->insData[INPUT_IMAGE].lock()->dims.size() != 4))
41                 THROW_IE_EXCEPTION << "Unsupported shape of input blobs!";
42
43             grid_w_ = layer->GetParamAsInt("w", 0);
44             grid_h_ = layer->GetParamAsInt("h", 0);
45             stride_h_ = layer->GetParamAsFloat("stride_y", 0);
46             stride_w_ = layer->GetParamAsFloat("stride_x", 0);
47
48             addConfig(layer,
49                       {DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN)},
50                       {DataConfigurator(ConfLayout::PLN)});
51         } catch (InferenceEngine::details::InferenceEngineException &ex) {
52             errorMsg = ex.what();
53         }
54     }
55
56     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
57                        ResponseDesc *resp) noexcept override {
58         const int num_priors_ = inputs[INPUT_PRIORS]->getTensorDesc().getDims()[0];
59         assert(inputs[INPUT_PRIORS]->getTensorDesc().getDims()[1] == 4);
60
61         // Execute
62         const int layer_width = grid_w_ ? grid_w_ : inputs[INPUT_FEATUREMAP]->getTensorDesc().getDims()[3];
63         const int layer_height = grid_h_ ? grid_h_ : inputs[INPUT_FEATUREMAP]->getTensorDesc().getDims()[2];
64         const float step_w = stride_w_ ? stride_w_ : static_cast<float>(inputs[INPUT_IMAGE]->getTensorDesc().getDims()[3]) / layer_width;
65         const float step_h = stride_h_ ? stride_h_ : static_cast<float>(inputs[INPUT_IMAGE]->getTensorDesc().getDims()[2]) / layer_height;
66
67         const auto *bottom_data_0 = inputs[0]->buffer().as<const float *>();
68         auto *top_data_0 = outputs[OUTPUT_ROIS]->buffer().as<float *>();
69
70         for (int h = 0; h < layer_height; ++h) {
71             for (int w = 0; w < layer_width; ++w) {
72                 for (int s = 0; s < num_priors_; ++s) {
73                     top_data_0[0] = bottom_data_0[4 * s + 0] + step_w * (w + 0.5f);
74                     top_data_0[1] = bottom_data_0[4 * s + 1] + step_h * (h + 0.5f);
75                     top_data_0[2] = bottom_data_0[4 * s + 2] + step_w * (w + 0.5f);
76                     top_data_0[3] = bottom_data_0[4 * s + 3] + step_h * (h + 0.5f);
77                     top_data_0 += 4;
78                 }
79             }
80         }
81
82         return OK;
83     }
84
85 private:
86     int grid_w_;
87     int grid_h_;
88     float stride_w_;
89     float stride_h_;
90 };
91
92
93 REG_FACTORY_FOR(ImplFactory<ExperimentalDetectronPriorGridGeneratorImpl>, ExperimentalDetectronPriorGridGenerator);
94
95 }  // namespace Cpu
96 }  // namespace Extensions
97 }  // namespace InferenceEngine