Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / builders / argmax_layer_test.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <ie_builders.hpp>
8 #include <builders/ie_argmax_layer.hpp>
9
10 #include "builder_test.hpp"
11
12 using namespace testing;
13 using namespace InferenceEngine;
14
15 class ArgMaxLayerBuilderTest : public BuilderTestCommon {};
16
17 TEST_F(ArgMaxLayerBuilderTest, getExistsLayerFromNetworkBuilder) {
18     Builder::Network network("network");
19     Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
20     argMaxLayer.setAxis(1);
21     argMaxLayer.setOutMaxVal(0);
22     argMaxLayer.setTopK(20);
23     size_t ind = 0;
24     ASSERT_NO_THROW(ind = network.addLayer(argMaxLayer));
25     Builder::ArgMaxLayer layerFromNetwork(network.getLayer(ind));
26     ASSERT_EQ(argMaxLayer.getAxis(), layerFromNetwork.getAxis());
27     ASSERT_EQ(argMaxLayer.getOutMaxVal(), layerFromNetwork.getOutMaxVal());
28     ASSERT_EQ(argMaxLayer.getTopK(), layerFromNetwork.getTopK());
29 }
30
31 TEST_F(ArgMaxLayerBuilderTest, cannotAddLayerWithWrongAxis) {
32     Builder::Network network("network");
33     Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
34     argMaxLayer.setAxis(500);  // here
35     argMaxLayer.setOutMaxVal(0);
36     argMaxLayer.setTopK(20);
37     ASSERT_THROW(network.addLayer(argMaxLayer), InferenceEngine::details::InferenceEngineException);
38 }
39
40 TEST_F(ArgMaxLayerBuilderTest, cannotAddLayerWithWrongOutMaxVal) {
41     Builder::Network network("network");
42     Builder::ArgMaxLayer argMaxLayer("ArgMax layer");
43     argMaxLayer.setAxis(1);
44     argMaxLayer.setOutMaxVal(500);  // here
45     argMaxLayer.setTopK(20);
46     ASSERT_THROW(network.addLayer(argMaxLayer), InferenceEngine::details::InferenceEngineException);
47 }