1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ext_list.hpp"
6 #include "ext_base.hpp"
15 namespace InferenceEngine {
16 namespace Extensions {
// ArgMax layer implementation derived from ExtLayerBase: for every slice of the
// input it selects the top_k_ largest values and writes their indices and/or
// values to the output blob.
// NOTE(review): this view of the file has gaps (the embedded line numbers skip),
// so the `try` opener, the branch conditions, closing braces and the member
// declarations are not visible here. Comments describe only the visible lines.
19 class ArgMaxImpl: public ExtLayerBase {
// Reads the layer parameters and registers the layer configuration.
// The check below throws unless the layer has exactly one input edge and at
// least one output edge.
21 explicit ArgMaxImpl(const CNNLayer* layer) {
23 if (layer->insData.size() != 1 || layer->outData.empty())
24 THROW_IE_EXCEPTION << "Incorrect number of input/output edges!";
// "out_max_val" is read as a bool (the `false` argument is presumably the
// default when the parameter is absent); "top_k" is read as an int with no
// fallback argument.
26 out_max_val_ = layer->GetParamAsBool("out_max_val", false);
27 top_k_ = layer->GetParamAsInt("top_k");
// has_axis_ records whether an "axis" parameter exists; when it does, its value
// is parsed with std::stoi, otherwise axis_index_ is 0.
29 has_axis_ = (layer->params.find("axis") != layer->params.end());
30 axis_index_ = has_axis_ ?
31 std::stoi(layer->params.at("axis")) :0;
// One input and one output, both configured with ConfLayout::PLN.
33 addConfig(layer, {DataConfigurator(ConfLayout::PLN)}, {DataConfigurator(ConfLayout::PLN)});
// The handler body is not visible in this view.
34 } catch (InferenceEngine::details::InferenceEngineException &ex) {
// Computes the top_k_ selection from inputs[0] and writes it to outputs[0].
// Both buffers are accessed through float pointers.
// NOTE(review): `resp` is not used in the visible lines, and the return
// statement is not visible either.
39 StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs,
40 ResponseDesc *resp) noexcept override {
41 SizeVector in_dims = inputs[0]->getTensorDesc().getDims();
// NOTE(review): out_dims is not referenced in the visible lines.
42 SizeVector out_dims = outputs[0]->getTensorDesc().getDims();
// Lines 46-48 (presumably the has_axis_ branch; the condition is not visible):
// a negative axis index is shifted by the number of input dimensions, `dim` is
// the input extent along that axis, and axis_dist is count(dims, axis_) / dim.
46 int axis_ = (axis_index_ < 0) ? axis_index_ + static_cast<int>(in_dims.size()) : axis_index_;
47 dim = static_cast<int>(inputs[0]->getTensorDesc().getDims()[axis_]);
48 axis_dist = count(inputs[0]->getTensorDesc().getDims(), axis_) / dim;
// Alternative branch (presumably no axis given): `dim` is count() over the
// input dimensions starting at index 1. The matching axis_dist assignment is
// not visible here.
50 dim = count(inputs[0]->getTensorDesc().getDims(), 1);
54 float* src_data = inputs[0]->buffer();
55 float* dst_data = outputs[0]->buffer();
// `num` is the number of slices processed by the outer loop. src_vector is a
// scratch buffer of `dim` (value, index) pairs, reused for every slice.
57 int num = count(in_dims) / dim;
58 std::vector<std::pair<float, int> > src_vector(dim);
60 for (int i = 0; i < num; ++i) {
// Gather the `dim` values of slice i, pairing each with its position j. Source
// elements of one slice are axis_dist apart.
61 for (int j = 0; j < dim; ++j) {
62 src_vector[j] = std::make_pair(
63 src_data[(i / axis_dist * dim + j) * axis_dist + i % axis_dist], j);
// Puts the top_k_ greatest pairs first, in descending order. Pairs compare by
// value and then by index.
// NOTE(review): no check that top_k_ lies in [0, dim] is visible. If it does
// not, begin() + top_k_ is outside the vector; confirm it is validated elsewhere.
66 std::partial_sort(src_vector.begin(), src_vector.begin() + top_k_,
67 src_vector.end(), std::greater<std::pair<float, int> >());
// Write the selected entries. The conditions choosing between the three output
// forms below are not visible in this view (presumably out_max_val_ and has_axis_).
69 for (int j = 0; j < top_k_; ++j) {
72 // Produces max_val per axis
73 dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = src_vector[j].first;
75 // Produces max_ind and max_val
// For slice i: top_k_ indices (stored as float) followed by top_k_ values.
76 dst_data[2 * i * top_k_ + j] = static_cast<float>(src_vector[j].second);
77 dst_data[2 * i * top_k_ + top_k_ + j] = src_vector[j].first;
80 // Produces max_ind per axis
81 dst_data[(i / axis_dist * top_k_ + j) * axis_dist + i % axis_dist] = static_cast<float>(src_vector[j].second);
// Iterates over dims indices [start_ind, end_ind) and returns the accumulated
// result narrowed to int. The accumulator's declaration and the loop body are
// not visible here; presumably this is the product of those extents.
// `dims` is taken by value.
95 inline int count(SizeVector dims, size_t start_ind, size_t end_ind) {
97 for (size_t i = start_ind; i < end_ind; i++)
99 return static_cast<int>(count);
// Overload covering dims from start_ind (default 0) up to dims.size().
102 inline int count(SizeVector dims, size_t start_ind = 0) {
103 return count(dims, start_ind, dims.size());
// Invokes the REG_FACTORY_FOR macro with ImplFactory<ArgMaxImpl> and the name
// ArgMax. The macro is defined outside this file; presumably it registers the
// implementation factory for layers of type "ArgMax".
107 REG_FACTORY_FOR(ImplFactory<ArgMaxImpl>, ArgMax);
110 } // namespace Extensions
111 } // namespace InferenceEngine