1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include <ie_builders.hpp>
9 #include "builder_test.hpp"
11 using namespace testing;
12 using namespace InferenceEngine;
14 class SplitLayerBuilderTest : public BuilderTestCommon {};
16 TEST_F(SplitLayerBuilderTest, CreateIdentitySplitLayer) {
17 Builder::Network builder("network");
18 SizeVector shape = {1, 4, 3, 4};
19 idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
20 layerId = builder.addLayer({layerId}, Builder::SplitLayer("identity").setOutputPorts({Port()}));
21 builder.addLayer({layerId}, Builder::OutputLayer("output"));
23 const auto network = builder.build();
24 ASSERT_EQ(shape, network->getLayer(layerId)->getOutputPorts()[0].shape());
27 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputs) {
28 Builder::Network builder("network");
29 SizeVector shape = {1, 4, 3, 4};
30 SizeVector outShape = {1, 2, 3, 4};
31 idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
32 layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setOutputPorts({Port(), Port()}));
33 builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
34 builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
36 const auto network = builder.build();
37 ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[0].shape());
38 ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[1].shape());
41 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAndOneInitialized) {
42 Builder::Network builder("network");
43 SizeVector shape = {1, 4, 3, 4};
44 SizeVector outShape1 = {1, 3, 3, 4};
45 SizeVector outShape2 = {1, 1, 3, 4};
46 idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
47 layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setOutputPorts({Port(outShape1), Port()}));
48 builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
49 builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
51 const auto network = builder.build();
52 ASSERT_EQ(outShape1, network->getLayer(layerId)->getOutputPorts()[0].shape());
53 ASSERT_EQ(outShape2, network->getLayer(layerId)->getOutputPorts()[1].shape());
56 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAxis3) {
57 Builder::Network builder("network");
58 SizeVector shape = {1, 4, 3, 4};
59 SizeVector outShape = {1, 4, 3, 2};
60 idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
61 layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setAxis(3).setOutputPorts({Port(), Port()}));
62 builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
63 builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
65 const auto network = builder.build();
66 ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[0].shape());
67 ASSERT_EQ(outShape, network->getLayer(layerId)->getOutputPorts()[1].shape());
70 TEST_F(SplitLayerBuilderTest, CreateSplitLayerWithTwoOutputsAxis3AndOneInitialized) {
71 Builder::Network builder("network");
72 SizeVector shape = {1, 4, 3, 4};
73 SizeVector outShape1 = {1, 4, 3, 1};
74 SizeVector outShape2 = {1, 4, 3, 3};
75 idx_t layerId = builder.addLayer(Builder::InputLayer("input").setPort(Port(shape, Precision::FP16)));
76 layerId = builder.addLayer({layerId}, Builder::SplitLayer("split").setAxis(3).setOutputPorts({Port(outShape1), Port()}));
77 builder.addLayer({{layerId}}, Builder::OutputLayer("output1"));
78 builder.addLayer({{layerId, 1}}, Builder::OutputLayer("output2"));
80 const auto network = builder.build();
81 ASSERT_EQ(outShape1, network->getLayer(layerId)->getOutputPorts()[0].shape());
82 ASSERT_EQ(outShape2, network->getLayer(layerId)->getOutputPorts()[1].shape());