Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / builders / ie_argmax_layer.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <builders/ie_argmax_layer.hpp>
6 #include <ie_cnn_layer_builder.h>
7
8 #include <vector>
9 #include <string>
10
11 using namespace InferenceEngine;
12
13 Builder::ArgMaxLayer::ArgMaxLayer(const std::string& name): LayerDecorator("ArgMax", name) {
14     getLayer()->getOutputPorts().resize(1);
15     getLayer()->getInputPorts().resize(1);
16 }
17
18 Builder::ArgMaxLayer::ArgMaxLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
19     checkType("ArgMax");
20 }
21
22 Builder::ArgMaxLayer::ArgMaxLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
23     checkType("ArgMax");
24 }
25
26 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setName(const std::string& name) {
27     getLayer()->setName(name);
28     return *this;
29 }
30
31 const Port& Builder::ArgMaxLayer::getPort() const {
32     return getLayer()->getInputPorts()[0];
33 }
34
35 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setPort(const Port &port) {
36     getLayer()->getInputPorts()[0] = port;
37     getLayer()->getOutputPorts()[0] = port;
38     return *this;
39 }
40
41 int Builder::ArgMaxLayer::getAxis() const {
42     return getLayer()->getParameters().at("axis");
43 }
44 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setAxis(int axis) {
45     getLayer()->getParameters()["axis"] = axis;
46     return *this;
47 }
48 size_t Builder::ArgMaxLayer::getTopK() const {
49     return getLayer()->getParameters().at("top_k");
50 }
51 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setTopK(size_t topK) {
52     getLayer()->getParameters()["top_k"] = topK;
53     return *this;
54 }
55 size_t Builder::ArgMaxLayer::getOutMaxVal() const {
56     return getLayer()->getParameters().at("out_max_val");
57 }
58 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setOutMaxVal(size_t outMaxVal) {
59     getLayer()->getParameters()["out_max_val"] = outMaxVal;
60     return *this;
61 }
62
63 REG_VALIDATOR_FOR(ArgMax, [] (const InferenceEngine::Builder::Layer::CPtr& input_layer, bool partial) {
64     if (!input_layer->getInputPorts().empty() &&
65         !input_layer->getOutputPorts().empty() &&
66         !input_layer->getInputPorts()[0].shape().empty() &&
67         !input_layer->getOutputPorts()[0].shape().empty() &&
68         input_layer->getInputPorts()[0].shape() != input_layer->getOutputPorts()[0].shape()) {
69         THROW_IE_EXCEPTION << "Input and output ports should be equal";
70     }
71     Builder::ArgMaxLayer layer(input_layer);
72     if (layer.getAxis() > 1) {
73         THROW_IE_EXCEPTION << "axis supports only 0 and 1 values.";
74     }
75     if (layer.getOutMaxVal() > 1) {
76         THROW_IE_EXCEPTION << "OutMaxVal supports only 0 and 1 values.";
77     }
78 });
79
80 REG_CONVERTER_FOR(ArgMax, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
81     layer.getParameters()["axis"] = cnnLayer->GetParamAsInt("axis");
82     layer.getParameters()["top_k"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("top_k"));
83     layer.getParameters()["out_max_val"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("out_max_val"));
84 });
85
86