Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_topkrois_onnx.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7 #include <algorithm>
8 #include <cassert>
9 #include <vector>
10
11
12 namespace InferenceEngine {
13 namespace Extensions {
14 namespace Cpu {
15
16 class ExperimentalDetectronTopKROIsImpl: public ExtLayerBase {
17 private:
18     // Inputs:
19     //      rois, shape [n, 4]
20     //      rois_probs, shape [n]
21     // Outputs:
22     //      top_rois, shape [max_rois, 4]
23
24     const int INPUT_ROIS {0};
25     const int INPUT_PROBS {1};
26
27     const int OUTPUT_ROIS {0};
28
29 public:
30     explicit ExperimentalDetectronTopKROIsImpl(const CNNLayer* layer) {
31         try {
32             if (layer->insData.size() != 2 || layer->outData.empty())
33                 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
34
35             if (layer->insData[INPUT_ROIS].lock()->dims.size() != 2 ||
36                 layer->insData[INPUT_PROBS].lock()->dims.size() != 1)
37                 THROW_IE_EXCEPTION << "Unsupported shape of input blobs!";
38
39             max_rois_num_ = layer->GetParamAsInt("max_rois", 0);
40
41             addConfig(layer,
42                       {DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN)},
43                       {DataConfigurator(ConfLayout::PLN)});
44         } catch (InferenceEngine::details::InferenceEngineException &ex) {
45             errorMsg = ex.what();
46         }
47     }
48
49     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
50                        ResponseDesc *resp) noexcept override {
51         const int input_rois_num = inputs[INPUT_ROIS]->getTensorDesc().getDims()[0];
52         const int top_rois_num = std::min(max_rois_num_, input_rois_num);
53
54         auto *input_rois = inputs[INPUT_ROIS]->buffer().as<const float *>();
55         auto *input_probs = inputs[INPUT_PROBS]->buffer().as<const float *>();
56         auto *output_rois = outputs[OUTPUT_ROIS]->buffer().as<float *>();
57
58         std::vector<size_t> idx(input_rois_num);
59         iota(idx.begin(), idx.end(), 0);
60         // FIXME. partial_sort is enough here.
61         sort(idx.begin(), idx.end(), [&input_probs](size_t i1, size_t i2) {return input_probs[i1] > input_probs[i2];});
62
63         for (int i = 0; i < top_rois_num; ++i) {
64             std::memcpy(output_rois + 4 * i, input_rois + 4 * idx[i], 4 * sizeof(float));
65         }
66
67         return OK;
68     }
69
70 private:
71     int max_rois_num_;
72 };
73
74 REG_FACTORY_FOR(ImplFactory<ExperimentalDetectronTopKROIsImpl>, ExperimentalDetectronTopKROIs);
75
76 }  // namespace Cpu
77 }  // namespace Extensions
78 }  // namespace InferenceEngine