Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / transformations / eltwise_broadcast_test.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <transform/transform_network.hpp>
8 #include <transform/transformations/eltwise_broadcast.hpp>
9 #include <ie_builders.hpp>
10
11 #include "tranformations_test.hpp"
12
13 using namespace testing;
14 using namespace InferenceEngine;
15
16 class TransformNetworkTest: public TransformationTestCommon {};
17
18 TEST_F(TransformationTestCommon, EltwiseBroadcastOneDimension) {
19     Builder::Network builder("eltwiseBroadcast");
20
21     idx_t firstInputId = builder.addLayer(Builder::InputLayer("FirstInput").setPort(Port({1, 3, 227, 1})));
22     idx_t secondInputId = builder.addLayer(Builder::InputLayer("SecondInput").setPort(Port({1, 3, 227, 227})));
23     idx_t eltwiseSumId = builder.addLayer({firstInputId, secondInputId}, Builder::EltwiseLayer("Sum").
24                                                                          setEltwiseType(Builder::EltwiseLayer::EltwiseType::SUM).
25                                                                          setOutputPort(Port({1, 3, 227, 227})));
26     auto network = Transform::Network(builder);
27
28     Transform::TransformationEltwiseBroadcast transformationEltwiseBroadcast;
29     transformationEltwiseBroadcast.execute(network);
30     auto firstInputLayer = network.getLayer(firstInputId);
31     auto tileLayer = network.getLayer(firstInputId).getOutPort().getConnection().getDestination().getLayer();
32     ASSERT_EQ(tileLayer.getType(), "Tile");
33     ASSERT_EQ(tileLayer.getParameter("axis").as<size_t>(), 3);
34     ASSERT_EQ(tileLayer.getParameter("tiles").as<size_t>(), 227);
35     ASSERT_EQ(firstInputLayer.getOutPort().getConnection().getDestination().getLayer().getId(), tileLayer.getId());
36     ASSERT_EQ(tileLayer.getOutPort().getConnection().getDestination().getLayer().getId(), eltwiseSumId);
37 }
38
39 TEST_F(TransformationTestCommon, EltwiseBroadcastTwoDimensions) {
40     Builder::Network builder("eltwiseBroadcast");
41
42     idx_t firstInputId = builder.addLayer(Builder::InputLayer("FirstInput").setPort(Port({1, 1, 227, 1})));
43     idx_t secondInputId = builder.addLayer(Builder::InputLayer("SecondInput").setPort(Port({1, 3, 227, 227})));
44     idx_t eltwiseSumId = builder.addLayer({firstInputId, secondInputId}, Builder::EltwiseLayer("Sum").
45                                                                          setEltwiseType(Builder::EltwiseLayer::EltwiseType::SUM).
46                                                                          setOutputPort(Port({1, 3, 227, 227})));
47     auto network = Transform::Network(builder);
48
49     Transform::TransformationEltwiseBroadcast transformationEltwiseBroadcast;
50     transformationEltwiseBroadcast.execute(network);
51     auto firstInputLayer = network.getLayer(firstInputId);
52     auto tile1Layer = network.getLayer(firstInputId).getOutPort().getConnection().getDestination().getLayer();
53     auto tile2Layer = tile1Layer.getOutPort().getConnection().getDestination().getLayer();
54     ASSERT_EQ(tile1Layer.getType(), "Tile");
55     ASSERT_EQ(tile1Layer.getParameter("axis").as<size_t>(), 1);
56     ASSERT_EQ(tile1Layer.getParameter("tiles").as<size_t>(), 3);
57     ASSERT_EQ(tile2Layer.getType(), "Tile");
58     ASSERT_EQ(tile2Layer.getParameter("axis").as<size_t>(), 3);
59     ASSERT_EQ(tile2Layer.getParameter("tiles").as<size_t>(), 227);
60     ASSERT_EQ(firstInputLayer.getOutPort().getConnection().getDestination().getLayer().getId(), tile1Layer.getId());
61     ASSERT_EQ(tile1Layer.getOutPort().getConnection().getDestination().getLayer().getId(), tile2Layer.getId());
62     ASSERT_EQ(tile2Layer.getOutPort().getConnection().getDestination().getLayer().getId(), eltwiseSumId);
63 }