Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / transformations / sub_test.cpp
1 // Copyright (C) 2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <string.h>
7 #include <transform/transform_network.hpp>
8 #include <transform/transformations/sub.hpp>
9 #include <ie_builders.hpp>
10
11 #include "tranformations_test.hpp"
12
13 using namespace testing;
14 using namespace InferenceEngine;
15
16 class TransformNetworkTest: public TransformationTestCommon {};
17
18 TEST_F(TransformationTestCommon, Sub) {
19     Builder::Network builder("sub");
20
21     idx_t firstInputId = builder.addLayer(Builder::InputLayer("FirstInput").setPort(Port({1,3, 227, 227})));
22     idx_t secondInputId = builder.addLayer(Builder::InputLayer("SecondInput").setPort(Port({1,3, 227, 227})));
23     idx_t eltwiseSubId = builder.addLayer({firstInputId, secondInputId}, Builder::EltwiseLayer("Sub").setEltwiseType(Builder::EltwiseLayer::EltwiseType::SUB));
24     idx_t clampId = builder.addLayer({eltwiseSubId}, Builder::ClampLayer("clamp"));
25     auto network = Transform::Network(builder);
26
27     Transform::TransformationSub transformationSub;
28     transformationSub.execute(network);
29     ASSERT_THROW(network.getLayer("Sub"), InferenceEngine::details::InferenceEngineException);
30     auto sumLayer = network.getLayer(firstInputId).getOutPort().getConnection().getDestination().getLayer();
31     auto powerLayer = network.getLayer(secondInputId).getOutPort().getConnection().getDestination().getLayer();
32     ASSERT_EQ(sumLayer.getType(), "Eltwise");
33     ASSERT_EQ(sumLayer.getParameter("operation").as<std::string>(), "sum");
34     ASSERT_EQ(powerLayer.getType(), "Power");
35     ASSERT_EQ(powerLayer.getParameter("power").as<float>(), 1.0f);
36     ASSERT_EQ(powerLayer.getParameter("scale").as<float>(), -1.0f);
37     ASSERT_EQ(powerLayer.getParameter("shift").as<float>(), 0.0f);
38     ASSERT_EQ(sumLayer.getOutPort().getConnection().getDestination().getLayer().getId(), clampId);
39 }