12b5057cad86afdefaed6a9e93a1b6c59dcadf02
[platform/upstream/dldt.git] / inference-engine / src / extension / ext_argmax.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "ext_list.hpp"
7 #include "ext_base.hpp"
8
9 #include <algorithm>
10 #include <string>
11 #include <vector>
12 #include <cmath>
13 #include <utility>
14 #include <functional>
15
16 namespace InferenceEngine {
17 namespace Extensions {
18 namespace Cpu {
19
20 class ArgMaxImpl: public ExtLayerBase {
21 public:
22     explicit ArgMaxImpl(const CNNLayer* layer) {
23         try {
24             if (layer->insData.size() != 1 || layer->outData.empty())
25                 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
26
27             out_max_val_ = static_cast<bool>(layer->GetParamAsInt("out_max_val"));
28             top_k_       = layer->GetParamAsInt("top_k");
29
30             has_axis_ = (layer->params.find("axis") != layer->params.end());
31             axis_index_ = has_axis_ ?
32                                 std::stoi(layer->params.at("axis")) :0;
33
34             addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
35         } catch (InferenceEngine::details::InferenceEngineException &ex) {
36             errorMsg = ex.what();
37         }
38     }
39
40     StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
41                        ResponseDesc *resp) noexcept override {
42         SizeVector in_dims = inputs[0]->getTensorDesc().getDims();
43         SizeVector out_dims = outputs[0]->getTensorDesc().getDims();
44
45         int dim, axis_dist;
46         if (has_axis_) {
47             int axis_ = (axis_index_ < 0) ? axis_index_ + static_cast<int>(in_dims.size()) : axis_index_;
48             dim = static_cast<int>(inputs[0]->getTensorDesc().getDims()[axis_]);
49             axis_dist = count(inputs[0]->getTensorDesc().getDims(), axis_) / dim;
50         } else {
51             dim = count(inputs[0]->getTensorDesc().getDims(), 1);
52             axis_dist = 1;
53         }
54
55         float* src_data = inputs[0]->buffer();
56         float* dst_data = outputs[0]->buffer();
57
58         int num = count(in_dims) / dim;
59         std::vector<std::pair<float, int> > src_vector(dim);
60
61         for (int i = 0; i < num; ++i) {
62             for (int j = 0; j < dim; ++j) {
63                 src_vector[j] = std::make_pair(
64                         src_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
65             }
66
67             std::partial_sort(src_vector.begin(), src_vector.begin() + top_k_,
68                               src_vector.end(), std::greater<std::pair<float, int> >());
69
70             for (int j = 0; j < top_k_; ++j) {
71                 if (out_max_val_) {
72                     if (has_axis_) {
73                         // Produces max_val per axis
74                         dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].first;
75                     } else {
76                         // Produces max_ind and max_val
77                         dst_data[2 * i * top_k_ + j] = src_vector[j].second;
78                         dst_data[2 * i * top_k_ + top_k_ + j] = src_vector[j].first;
79                     }
80                 } else {
81                     // Produces max_ind per axis
82                     dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].second;
83                 }
84             }
85         }
86
87         return OK;
88     }
89
90 private:
91     bool out_max_val_;
92     int top_k_;
93     bool has_axis_;
94     int axis_index_;
95
96     inline int count(SizeVector dims, size_t start_ind, size_t end_ind) {
97         size_t count = 1;
98         for (size_t i = start_ind; i < end_ind; i++)
99             count *= dims[i];
100         return static_cast<int>(count);
101     }
102
103     inline int count(SizeVector dims, size_t start_ind = 0) {
104         return count(dims, start_ind, dims.size());
105     }
106 };
107
108 REG_FACTORY_FOR(ImplFactory<ArgMaxImpl>, ArgMax);
109
110 }  // namespace Cpu
111 }  // namespace Extensions
112 }  // namespace InferenceEngine