1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include "ext_list.hpp"
7 #include "ext_base.hpp"
16 namespace InferenceEngine {
17 namespace Extensions {
// NOTE(review): every line of this copy carries a numeric prefix, and many
// lines are absent. Examples are the `try {` that pairs with the `catch` below,
// the `public:` specifier and the catch handler body. The file will not
// compile as-is; restore it from the upstream source before building.
//
// ArgMaxImpl: extension layer (derived from ExtLayerBase) that reports the
// top_k_ largest values and/or their indices for its single input.
20 class ArgMaxImpl: public ExtLayerBase {
// Validates the edges, parses the layer parameters, and registers one input
// and one output data configuration (ConfLayout::PLN).
//   layer - layer description; must have exactly one input edge and at least
//           one output edge, otherwise an exception is thrown.
// Parameters read:
//   "out_max_val" - parsed with GetParamAsInt and converted to bool.
//   "top_k"       - number of entries reported per slice. NOTE(review): there
//                   is no visible check that it is positive and <= the reduced
//                   dimension, and execute() relies on both.
//   "axis"        - optional. has_axis_ records whether it is present, and
//                   axis_index_ falls back to 0 when it is absent.
22 explicit ArgMaxImpl(const CNNLayer* layer) {
24 if (layer->insData.size() != 1 || layer->outData.empty())
25 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
27 out_max_val_ = static_cast<bool>(layer->GetParamAsInt("out_max_val"));
28 top_k_ = layer->GetParamAsInt("top_k");
30 has_axis_ = (layer->params.find("axis") != layer->params.end());
// NOTE(review): std::stoi throws std::invalid_argument / std::out_of_range on
// a malformed value. The visible catch clause handles only
// InferenceEngineException, so such errors would propagate. Confirm intent.
31 axis_index_ = has_axis_ ?
32 std::stoi(layer->params.at("axis")) :0;
// ConfLayout::PLN is presumably the plain/planar layout; confirm in ext_base.
34 addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
35 } catch (InferenceEngine::details::InferenceEngineException &ex) {
// Computes the top_k_ entries of inputs[0] along the reduced dimension and
// writes them to outputs[0].
//   inputs[0]  - source blob. outputs[0] - destination blob. Both buffers are
//                accessed through float*, so FP32 data is assumed (TODO confirm).
//   resp       - not referenced in the lines visible here.
//   out_dims   - read but not used in the lines visible here.
// With an axis, the reduced dimension is in_dims[axis]. Each of the `num`
// slices is strided by axis_dist, which is presumably the product of the
// dimensions after the axis. Without an axis, each slice covers all dimensions
// after dims[0].
// NOTE(review): several pieces are missing from this copy: the declarations of
// `dim` / `axis_dist`, the has_axis_ branch structure, the conditions that
// choose among the three output writes below, and the return statement.
40 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
41 ResponseDesc *resp) noexcept override {
42 SizeVector in_dims = inputs[0]->getTensorDesc().getDims();
43 SizeVector out_dims = outputs[0]->getTensorDesc().getDims();
// A negative axis counts from the back. NOTE(review): the normalized value is
// not range-checked before it is used to index getDims()[axis_].
47 int axis_ = (axis_index_ < 0) ? axis_index_ + static_cast<int>(in_dims.size()) : axis_index_;
48 dim = static_cast<int>(inputs[0]->getTensorDesc().getDims()[axis_]);
// NOTE(review): divides by dim, so a zero-sized axis would divide by zero.
49 axis_dist = count(inputs[0]->getTensorDesc().getDims(), axis_) / dim;
// No axis: one slice per leading index, spanning every dimension after dims[0].
51 dim = count(inputs[0]->getTensorDesc().getDims(), 1);
55 float* src_data = inputs[0]->buffer();
56 float* dst_data = outputs[0]->buffer();
// Number of independent slices, each holding `dim` candidates.
58 int num = count(in_dims) / dim;
// (value, index) pairs for one slice. Allocated once, outside the loop, and
// reused for every slice.
59 std::vector<std::pair<float, int> > src_vector(dim);
61 for (int i = 0; i < num; ++i) {
// Gather slice i:
//   i / axis_dist selects the outer block,
//   i % axis_dist selects the offset inside that block,
//   j steps along the reduced dimension with stride axis_dist.
62 for (int j = 0; j < dim; ++j) {
63 src_vector[j] = std::make_pair(
64 src_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
// Move the top_k_ largest pairs to the front, in descending order. Pairs
// compare by value first and then by index, so among equal values the larger
// index ranks first. NOTE(review): begin() + top_k_ is undefined behavior when
// top_k_ > dim or top_k_ < 0, and nothing visible guards against that.
67 std::partial_sort(src_vector.begin(), src_vector.begin() + top_k_,
68 src_vector.end(), std::greater<std::pair<float, int> >());
70 for (int j = 0; j < top_k_; ++j) {
// Values, laid out like the input but with the reduced dimension shrunk to
// top_k_.
73 // Produces max_val per axis
74 dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].first;
// Slice i owns 2 * top_k_ consecutive floats: top_k_ indices (converted to
// float) followed by top_k_ values.
76 // Produces max_ind and max_val
77 dst_data[2 * i * top_k_ + j] = src_vector[j].second;
78 dst_data[2 * i * top_k_ + top_k_ + j] = src_vector[j].first;
// Indices (stored as float) in the same per-axis layout as the values case.
81 // Produces max_ind per axis
82 dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].second;
// Helper: presumably returns the product of dims[start_ind, end_ind) as an int.
// NOTE(review): this copy is missing the accumulator declaration (presumably a
// local named `count` that shadows the function) and the loop's multiply
// statement. As shown, the `return` reads as the loop body.
// Two further issues:
//   - `dims` is taken by value, so it is copied on every call.
//   - The accumulated value is narrowed with static_cast<int>, which can
//     overflow for very large tensors.
96 inline int count(SizeVector dims, size_t start_ind, size_t end_ind) {
98 for (size_t i = start_ind; i < end_ind; i++)
100 return static_cast<int>(count);
// Overload that covers dims from start_ind to the end (default start_ind = 0,
// i.e. every dimension). It forwards to the three-argument form.
// NOTE(review): this function's closing brace is missing from this copy.
103 inline int count(SizeVector dims, size_t start_ind = 0) {
104 return count(dims, start_ind, dims.size());
// Registers ArgMaxImpl with the extension factory under the name ArgMax.
// That name is presumably the layer type it serves. The macro is not defined in
// this file; confirm its definition in ext_list.hpp / ext_base.hpp.
108 REG_FACTORY_FOR(ImplFactory<ArgMaxImpl>, ArgMax);
111 } // namespace Extensions
112 } // namespace InferenceEngine