Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / builders / split_layer_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <ie_builders.hpp>
8
9 #include "builder_test.hpp"
10
11 using namespace testing;
12 using namespace InferenceEngine;
13
14 class SplitLayerBuilderTest : public BuilderTestCommon {};
15
16 TEST_F(SplitLayerBuilderTest, CreateIdentitySplitLayer) {
17     Builder::Network builder("network");
18     SizeVector shape = {1, 4, 3, 4};
19     idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
20     layerId = builder.addLayer({layerId}, Builder::SplitLayer("identity").setOutputPorts({Port()}));
21     builder.addLayer({layerId}, Builder::OutputLayer("output"));
22
23     const auto network = builder.build();
24     ASSERT_EQ(shape, network->getLayer(layerId)->getOutputPorts()[0].shape());
25 }
26
27 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputs) {
28     Builder::Network builder("network");
29     SizeVector shape = {1, 4, 3, 4};
30     SizeVector outShape = {1, 2, 3, 4};
31     idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
32     layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setOutputPorts({Port(), Port()}));
33     builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
34     builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
35
36     const auto network = builder.build();
37     ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[0].shape());
38     ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[1].shape());
39 }
40
41 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAndOneInitialized) {
42     Builder::Network builder("network");
43     SizeVector shape = {1, 4, 3, 4};
44     SizeVector outShape1 = {1, 3, 3, 4};
45     SizeVector outShape2 = {1, 1, 3, 4};
46     idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
47     layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setOutputPorts({Port(outShape1), Port()}));
48     builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
49     builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
50
51     const auto network = builder.build();
52     ASSERT_EQ(outShape1, network->getLayer(layerId)->getOutputPorts()[0].shape());
53     ASSERT_EQ(outShape2, network->getLayer(layerId)->getOutputPorts()[1].shape());
54 }
55
56 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAxis3) {
57     Builder::Network builder("network");
58     SizeVector shape = {1, 4, 3, 4};
59     SizeVector outShape = {1, 4, 3, 2};
60     idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
61     layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setAxis(3).setOutputPorts({Port(), Port()}));
62     builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
63     builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
64
65     const auto network = builder.build();
66     ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[0].shape());
67     ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[1].shape());
68 }
69
70 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAxis3AndOneInitialized) {
71     Builder::Network builder("network");
72     SizeVector shape = {1, 4, 3, 4};
73     SizeVector outShape1 = {1, 4, 3, 1};
74     SizeVector outShape2 = {1, 4, 3, 3};
75     idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
76     layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setAxis(3).setOutputPorts({Port(outShape1), Port()}));
77     builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
78     builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
79
80     const auto network = builder.build();
81     ASSERT_EQ(outShape1, network->getLayer(layerId)->getOutputPorts()[0].shape());
82     ASSERT_EQ(outShape2, network->getLayer(layerId)->getOutputPorts()[1].shape());
83 }