1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
12 namespace InferenceEngine {
13 namespace Extensions {
16 class ExperimentalDetectronTopKROIsImpl: public ExtLayerBase {
20 // rois_probs, shape [n]
22 // top_rois, shape [max_rois, 4]
24 const int INPUT_ROIS {0};
25 const int INPUT_PROBS {1};
27 const int OUTPUT_ROIS {0};
30 explicit ExperimentalDetectronTopKROIsImpl(const CNNLayer* layer) {
32 if (layer->insData.size() != 2 || layer->outData.empty())
33 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
35 if (layer->insData[INPUT_ROIS].lock()->dims.size() != 2 ||
36 layer->insData[INPUT_PROBS].lock()->dims.size() != 1)
37 THROW_IE_EXCEPTION << "Unsupported shape of input blobs!";
39 max_rois_num_ = layer->GetParamAsInt("max_rois", 0);
42 {DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN)},
43 {DataConfigurator(ConfLayout::PLN)});
44 } catch (InferenceEngine::details::InferenceEngineException &ex) {
49 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
50 ResponseDesc *resp) noexcept override {
51 const int input_rois_num = inputs[INPUT_ROIS]->getTensorDesc().getDims()[0];
52 const int top_rois_num = std::min(max_rois_num_, input_rois_num);
54 auto *input_rois = inputs[INPUT_ROIS]->buffer().as<const float *>();
55 auto *input_probs = inputs[INPUT_PROBS]->buffer().as<const float *>();
56 auto *output_rois = outputs[OUTPUT_ROIS]->buffer().as<float *>();
58 std::vector<size_t> idx(input_rois_num);
59 iota(idx.begin(), idx.end(), 0);
60 // FIXME. partial_sort is enough here.
61 sort(idx.begin(), idx.end(), [&input_probs](size_t i1, size_t i2) {return input_probs[i1] > input_probs[i2];});
63 for (int i = 0; i < top_rois_num; ++i) {
64 std::memcpy(output_rois + 4 * i, input_rois + 4 * idx[i], 4 * sizeof(float));
// Registers ExperimentalDetectronTopKROIsImpl (wrapped in ImplFactory) through the
// REG_FACTORY_FOR macro under the name ExperimentalDetectronTopKROIs.
// NOTE(review): presumably that name is the layer type string matched when the network
// is loaded; confirm against the REG_FACTORY_FOR definition in ext_list.hpp / ext_base.hpp.
74 REG_FACTORY_FOR(ImplFactory<ExperimentalDetectronTopKROIsImpl>, ExperimentalDetectronTopKROIs);
77 } // namespace Extensions
78 } // namespace InferenceEngine