Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_argmax.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
7
8 #include <algorithm>
9 #include <string>
10 #include <vector>
11 #include <cmath>
12 #include <utility>
13 #include <functional>
14
15 namespace InferenceEngine {
16 namespace Extensions {
17 namespace Cpu {
18
19 class ArgMaxImpl: public ExtLayerBase {
20 public:
21     explicit ArgMaxImpl(const CNNLayer* layer) {
22         try {
23             if (layer->insData.size() != 1 || layer->outData.empty())
24                 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
25
26             out_max_val_ = layer->GetParamAsBool("out_max_val", false);
27             top_k_       = layer->GetParamAsInt("top_k");
28
29             has_axis_ = (layer->params.find("axis") != layer->params.end());
30             axis_index_ = has_axis_ ?
31                                 std::stoi(layer->params.at("axis")) :0;
32
33             addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
34         } catch (InferenceEngine::details::InferenceEngineException &ex) {
35             errorMsg = ex.what();
36         }
37     }
38
39     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
40                        ResponseDesc *resp) noexcept override {
41         SizeVector in_dims = inputs[0]->getTensorDesc().getDims();
42         SizeVector out_dims = outputs[0]->getTensorDesc().getDims();
43
44         int dim, axis_dist;
45         if (has_axis_) {
46             int axis_ = (axis_index_ < 0) ? axis_index_ + static_cast<int>(in_dims.size()) : axis_index_;
47             dim = static_cast<int>(inputs[0]->getTensorDesc().getDims()[axis_]);
48             axis_dist = count(inputs[0]->getTensorDesc().getDims(), axis_) / dim;
49         } else {
50             dim = count(inputs[0]->getTensorDesc().getDims(), 1);
51             axis_dist = 1;
52         }
53
54         float* src_data = inputs[0]->buffer();
55         float* dst_data = outputs[0]->buffer();
56
57         int num = count(in_dims) / dim;
58         std::vector<std::pair<float, int> > src_vector(dim);
59
60         for (int i = 0; i < num; ++i) {
61             for (int j = 0; j < dim; ++j) {
62                 src_vector[j] = std::make_pair(
63                         src_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
64             }
65
66             std::partial_sort(src_vector.begin(), src_vector.begin() + top_k_,
67                               src_vector.end(), std::greater<std::pair<float, int> >());
68
69             for (int j = 0; j < top_k_; ++j) {
70                 if (out_max_val_) {
71                     if (has_axis_) {
72                         // Produces max_val per axis
73                         dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].first;
74                     } else {
75                         // Produces max_ind and max_val
76                         dst_data[2 * i * top_k_ + j] = static_cast<float>(src_vector[j].second);
77                         dst_data[2 * i * top_k_ + top_k_ + j] = src_vector[j].first;
78                     }
79                 } else {
80                     // Produces max_ind per axis
81                     dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = static_cast<float>(src_vector[j].second);
82                 }
83             }
84         }
85
86         return OK;
87     }
88
89 private:
90     bool out_max_val_;
91     int top_k_;
92     bool has_axis_;
93     int axis_index_;
94
95     inline int count(SizeVector dims, size_t start_ind, size_t end_ind) {
96         size_t count = 1;
97         for (size_t i = start_ind; i < end_ind; i++)
98             count *= dims[i];
99         return static_cast<int>(count);
100     }
101
102     inline int count(SizeVector dims, size_t start_ind = 0) {
103         return count(dims, start_ind, dims.size());
104     }
105 };
106
107 REG_FACTORY_FOR(ImplFactory<ArgMaxImpl>, ArgMax);
108
109 }  // namespace Cpu
110 }  // namespace Extensions
111 }  // namespace InferenceEngine