Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / builders / concat_layer_test.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <ie_builders.hpp>
8 #include <builders/ie_concat_layer.hpp>
9
10 #include "builder_test.hpp"
11
12 using namespace testing;
13 using namespace InferenceEngine;
14
15 class ConcatLayerBuilderTest : public BuilderTestCommon {};
16
17 TEST_F(ConcatLayerBuilderTest, getExistsLayerFromNetworkBuilderAxis) {
18     Builder::Network network("network");
19     Builder::ConcatLayer layer("concat layer");
20
21     layer.setAxis(0);
22     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
23     layer.setOutputPort(Port({1 + 3, 2, 55, 55}));
24
25     size_t ind = 0;
26     ASSERT_NO_THROW(ind = network.addLayer(layer));
27     network.getLayer(ind)->validate(false);
28     ASSERT_NO_THROW(network.getLayer(ind)->validate(false));
29     Builder::ConcatLayer layerFromNet(network.getLayer(ind));
30
31     ASSERT_EQ(layer.getAxis(), layerFromNet.getAxis());
32     ASSERT_EQ(layer.getInputPorts(), layerFromNet.getInputPorts());
33     ASSERT_EQ(layer.getOutputPort(), layerFromNet.getOutputPort());
34 }
35
36 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithNoInputPorts) {
37     Builder::Network network("network");
38     Builder::ConcatLayer layer("concat layer");
39
40     layer.setAxis(1);
41     layer.setOutputPort(Port({1, 2 + 4, 55, 55}));
42     // here should be layer.setInputPort(...)
43
44     size_t ind = 0;
45     ASSERT_NO_THROW(ind = network.addLayer(layer));
46     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
47 }
48
49 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithOneInputPort) {
50     Builder::Network network("network");
51     Builder::ConcatLayer layer("concat layer");
52
53     layer.setAxis(1);
54     layer.setInputPorts({Port({1, 2, 55, 55})});  // here
55     layer.setOutputPort(Port({1, 2 + 4, 55, 55}));
56
57     size_t ind = 0;
58     ASSERT_NO_THROW(ind = network.addLayer(layer));
59     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
60 }
61
62 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithWrongAxis) {
63     Builder::Network network("network");
64     Builder::ConcatLayer layer("concat layer");
65
66     layer.setAxis(50);  // here
67     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
68     layer.setOutputPort(Port({1 + 3, 2, 55, 55}));
69
70     size_t ind = 0;
71     ASSERT_NO_THROW(ind = network.addLayer(layer));
72     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
73 }
74
75 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithUnalignedPorts1) {
76     Builder::Network network("network");
77     Builder::ConcatLayer layer("concat layer");
78
79     layer.setAxis(0);
80     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
81     layer.setOutputPort(Port({1 + 3, 2, 55, 155}));  // should be {1 + 3, 2, 55, 55}
82
83     size_t ind = 0;
84     ASSERT_NO_THROW(ind = network.addLayer(layer));
85     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
86 }
87
88 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithUnalignedPorts2) {
89     Builder::Network network("network");
90     Builder::ConcatLayer layer("concat layer");
91
92     layer.setAxis(0);
93     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
94     layer.setOutputPort(Port({1 + 3, 2, 155, 55}));  // should be {1 + 3, 2, 55, 55}
95
96     size_t ind = 0;
97     ASSERT_NO_THROW(ind = network.addLayer(layer));
98     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
99 }
100
101 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithUnalignedPorts3) {
102     Builder::Network network("network");
103     Builder::ConcatLayer layer("concat layer");
104
105     layer.setAxis(0);
106     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
107     layer.setOutputPort(Port({100, 2, 55, 55}));  // should be {1 + 3, 2, 55, 55}
108
109     size_t ind = 0;
110     ASSERT_NO_THROW(ind = network.addLayer(layer));
111     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
112 }
113
114 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithUnalignedPorts4) {
115     Builder::Network network("network");
116     Builder::ConcatLayer layer("concat layer");
117
118     layer.setAxis(1);
119     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 55})});
120     layer.setOutputPort(Port({1, 100, 55, 55}));  // should be {1, 2 + 4, 55, 55}
121
122     size_t ind = 0;
123     ASSERT_NO_THROW(ind = network.addLayer(layer));
124     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
125 }
126
127 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithDifferentInputPorts1) {
128     Builder::Network network("network");
129     Builder::ConcatLayer layer("concat layer");
130
131     layer.setAxis(0);
132     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 55, 155})});  // here
133     layer.setOutputPort(Port({1 + 3, 4, 55, 55}));
134
135     size_t ind = 0;
136     ASSERT_NO_THROW(ind = network.addLayer(layer));
137     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
138 }
139
140 TEST_F(ConcatLayerBuilderTest, cannotCreateLayerWithDifferentInputPorts2) {
141     Builder::Network network("network");
142     Builder::ConcatLayer layer("concat layer");
143
144     layer.setAxis(0);
145     layer.setInputPorts({Port({1, 2, 55, 55}), Port({3, 2, 155, 55})});  // here
146     layer.setOutputPort(Port({1 + 3, 4, 55, 55}));
147
148     size_t ind = 0;
149     ASSERT_NO_THROW(ind = network.addLayer(layer));
150     ASSERT_THROW(network.getLayer(ind)->validate(false), InferenceEngine::details::InferenceEngineException);
151 }