1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <builders/ie_argmax_layer.hpp>
6 #include <ie_cnn_layer_builder.h>
11 using namespace InferenceEngine;
13 Builder::ArgMaxLayer::ArgMaxLayer(const std::string& name): LayerDecorator("ArgMax", name) {
14 getLayer()->getOutputPorts().resize(1);
15 getLayer()->getInputPorts().resize(1);
/**
 * @brief Wraps an existing (non-const `Layer::Ptr`) builder layer as an ArgMax decorator.
 *
 * The layer pointer is forwarded unchanged to LayerDecorator.
 * NOTE(review): the constructor body is not visible in this view; presumably it
 * verifies that the wrapped layer's type is "ArgMax" — confirm. Unlike the
 * name-based constructor, nothing visible here resizes the port lists.
 * @param layer Layer to decorate.
 */
18 Builder::ArgMaxLayer::ArgMaxLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
/**
 * @brief Wraps an existing read-only (`Layer::CPtr`) builder layer as an ArgMax decorator.
 *
 * This is the overload the ArgMax validator below uses (it receives a
 * `Layer::CPtr`). The pointer is forwarded unchanged to LayerDecorator.
 * NOTE(review): the constructor body is not visible in this view; presumably it
 * verifies that the wrapped layer's type is "ArgMax" — confirm.
 * @param layer Layer to decorate.
 */
22 Builder::ArgMaxLayer::ArgMaxLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
26 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setName(const std::string& name) {
27 getLayer()->setName(name);
31 const Port& Builder::ArgMaxLayer::getPort() const {
32 return getLayer()->getInputPorts()[0];
35 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setPort(const Port &port) {
36 getLayer()->getInputPorts()[0] = port;
37 getLayer()->getOutputPorts()[0] = port;
41 int Builder::ArgMaxLayer::getAxis() const {
42 return getLayer()->getParameters().at("axis");
44 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setAxis(int axis) {
45 getLayer()->getParameters()["axis"] = axis;
48 size_t Builder::ArgMaxLayer::getTopK() const {
49 return getLayer()->getParameters().at("top_k");
51 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setTopK(size_t topK) {
52 getLayer()->getParameters()["top_k"] = topK;
55 size_t Builder::ArgMaxLayer::getOutMaxVal() const {
56 return getLayer()->getParameters().at("out_max_val");
58 Builder::ArgMaxLayer& Builder::ArgMaxLayer::setOutMaxVal(size_t outMaxVal) {
59 getLayer()->getParameters()["out_max_val"] = outMaxVal;
63 REG_VALIDATOR_FOR(ArgMax, [] (const InferenceEngine::Builder::Layer::CPtr& input_layer, bool partial) {
64 if (!input_layer->getInputPorts().empty() &&
65 !input_layer->getOutputPorts().empty() &&
66 !input_layer->getInputPorts()[0].shape().empty() &&
67 !input_layer->getOutputPorts()[0].shape().empty() &&
68 input_layer->getInputPorts()[0].shape() != input_layer->getOutputPorts()[0].shape()) {
69 THROW_IE_EXCEPTION << "Input and output ports should be equal";
71 Builder::ArgMaxLayer layer(input_layer);
72 if (layer.getAxis() > 1) {
73 THROW_IE_EXCEPTION << "axis supports only 0 and 1 values.";
75 if (layer.getOutMaxVal() > 1) {
76 THROW_IE_EXCEPTION << "OutMaxVal supports only 0 and 1 values.";
// Converter from a legacy CNNLayer to an "ArgMax" builder layer.
// Copies three attributes into the builder layer's parameters:
//  - "axis" as int (read with GetParamAsInt);
//  - "top_k" and "out_max_val" read with GetParamAsUInt and cast to size_t,
//    which matches the size_t types used by setTopK/setOutMaxVal and
//    getTopK/getOutMaxVal.
// No default values are passed to the GetParamAs* calls.
// NOTE(review): presumably they throw when an attribute is absent — confirm.
80 REG_CONVERTER_FOR(ArgMax, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
81 layer.getParameters()["axis"] = cnnLayer->GetParamAsInt("axis");
82 layer.getParameters()["top_k"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("top_k"));
83 layer.getParameters()["out_max_val"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("out_max_val"));