1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include <transform/transform_network.hpp>
8 #include <transform/transformations/eltwise_broadcast.hpp>
9 #include <ie_builders.hpp>
11 #include "tranformations_test.hpp"
13 using namespace testing;
14 using namespace InferenceEngine;
16 class TransformNetworkTest: public TransformationTestCommon {};
18 TEST_F(TransformationTestCommon, EltwiseBroadcastOneDimension) {
19 Builder::Network builder("eltwiseBroadcast");
21 idx_t firstInputId = builder.addLayer(Builder::InputLayer("FirstInput").setPort(Port({1, 3, 227, 1})));
22 idx_t secondInputId = builder.addLayer(Builder::InputLayer("SecondInput").setPort(Port({1, 3, 227, 227})));
23 idx_t eltwiseSumId = builder.addLayer({firstInputId, secondInputId}, Builder::EltwiseLayer("Sum").
24 setEltwiseType(Builder::EltwiseLayer::EltwiseType::SUM).
25 setOutputPort(Port({1, 3, 227, 227})));
26 auto network = Transform::Network(builder);
28 Transform::TransformationEltwiseBroadcast transformationEltwiseBroadcast;
29 transformationEltwiseBroadcast.execute(network);
30 auto firstInputLayer = network.getLayer(firstInputId);
31 auto tileLayer = network.getLayer(firstInputId).getOutPort().getConnection().getDestination().getLayer();
32 ASSERT_EQ(tileLayer.getType(), "Tile");
33 ASSERT_EQ(tileLayer.getParameter("axis").as<size_t>(), 3);
34 ASSERT_EQ(tileLayer.getParameter("tiles").as<size_t>(), 227);
35 ASSERT_EQ(firstInputLayer.getOutPort().getConnection().getDestination().getLayer().getId(), tileLayer.getId());
36 ASSERT_EQ(tileLayer.getOutPort().getConnection().getDestination().getLayer().getId(), eltwiseSumId);
39 TEST_F(TransformationTestCommon, EltwiseBroadcastTwoDimensions) {
40 Builder::Network builder("eltwiseBroadcast");
42 idx_t firstInputId = builder.addLayer(Builder::InputLayer("FirstInput").setPort(Port({1, 1, 227, 1})));
43 idx_t secondInputId = builder.addLayer(Builder::InputLayer("SecondInput").setPort(Port({1, 3, 227, 227})));
44 idx_t eltwiseSumId = builder.addLayer({firstInputId, secondInputId}, Builder::EltwiseLayer("Sum").
45 setEltwiseType(Builder::EltwiseLayer::EltwiseType::SUM).
46 setOutputPort(Port({1, 3, 227, 227})));
47 auto network = Transform::Network(builder);
49 Transform::TransformationEltwiseBroadcast transformationEltwiseBroadcast;
50 transformationEltwiseBroadcast.execute(network);
51 auto firstInputLayer = network.getLayer(firstInputId);
52 auto tile1Layer = network.getLayer(firstInputId).getOutPort().getConnection().getDestination().getLayer();
53 auto tile2Layer = tile1Layer.getOutPort().getConnection().getDestination().getLayer();
54 ASSERT_EQ(tile1Layer.getType(), "Tile");
55 ASSERT_EQ(tile1Layer.getParameter("axis").as<size_t>(), 1);
56 ASSERT_EQ(tile1Layer.getParameter("tiles").as<size_t>(), 3);
57 ASSERT_EQ(tile2Layer.getType(), "Tile");
58 ASSERT_EQ(tile2Layer.getParameter("axis").as<size_t>(), 3);
59 ASSERT_EQ(tile2Layer.getParameter("tiles").as<size_t>(), 227);
60 ASSERT_EQ(firstInputLayer.getOutPort().getConnection().getDestination().getLayer().getId(), tile1Layer.getId());
61 ASSERT_EQ(tile1Layer.getOutPort().getConnection().getDestination().getLayer().getId(), tile2Layer.getId());
62 ASSERT_EQ(tile2Layer.getOutPort().getConnection().getDestination().getLayer().getId(), eltwiseSumId);