1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include <ie_builders.hpp>
8 #include <builders/ie_argmax_layer.hpp>
10 #include "builder_test.hpp"
12 using namespace testing;
13 using namespace InferenceEngine;
15 class ArgMaxLayerBuilderTest : public BuilderTestCommon {};
17 TEST_F(ArgMaxLayerBuilderTest, getExistsLayerFromNetworkBuilder) {
18 Builder::Network network("network");
19 Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
20 argMaxLayer.setAxis(1);
21 argMaxLayer.setOutMaxVal(0);
22 argMaxLayer.setTopK(20);
24 ASSERT_NO_THROW(ind = network.addLayer(argMaxLayer));
25 Builder::ArgMaxLayer layerFromNetwork(network.getLayer(ind));
26 ASSERT_EQ(argMaxLayer.getAxis(), layerFromNetwork.getAxis());
27 ASSERT_EQ(argMaxLayer.getOutMaxVal(), layerFromNetwork.getOutMaxVal());
28 ASSERT_EQ(argMaxLayer.getTopK(), layerFromNetwork.getTopK());
31 TEST_F(ArgMaxLayerBuilderTest, cannotAddLayerWithWrongAxis) {
32 Builder::Network network("network");
33 Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
34 argMaxLayer.setAxis(500); // here
35 argMaxLayer.setOutMaxVal(0);
36 argMaxLayer.setTopK(20);
37 ASSERT_THROW(network.addLayer(argMaxLayer), InferenceEngine::details::InferenceEngineException);
40 TEST_F(ArgMaxLayerBuilderTest, cannotAddLayerWithWrongOutMaxVal) {
41 Builder::Network network("network");
42 Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
43 argMaxLayer.setAxis(1);
44 argMaxLayer.setOutMaxVal(500); // here
45 argMaxLayer.setTopK(20);
46 ASSERT_THROW(network.addLayer(argMaxLayer), InferenceEngine::details::InferenceEngineException);