Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseActivationFunctionPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseActivationFunctionPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <luci/test/TestIOGraph.h>
22
23 #include <gtest/gtest.h>
24
25 namespace
26 {
27
28 using namespace luci::test;
29
30 /**
31  *  Simple graph for test
32  *
33  *  BEFORE
34  *
35  *         [Conv1]
36  *           |
37  *     [Activation func]
38  *           |
39  *         [Conv2]
40  *
41  *  AFTER
42  *
43  *   [Conv1 + Activation func]
44  *           |
45  *         [Conv2]
46  *
47  */
48 class ConvReluConvGraphlet
49 {
50 public:
51   ConvReluConvGraphlet() = default;
52
53   void init(loco::Graph *g)
54   {
55     _conv1 = g->nodes()->create<luci::CircleConv2D>();
56     _conv2 = g->nodes()->create<luci::CircleConv2D>();
57     _relu = g->nodes()->create<luci::CircleRelu>();
58     _conv1_f = g->nodes()->create<luci::CircleConst>();
59     _conv1_b = g->nodes()->create<luci::CircleConst>();
60     _conv2_f = g->nodes()->create<luci::CircleConst>();
61     _conv2_b = g->nodes()->create<luci::CircleConst>();
62
63     _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
64
65     _conv1->name("conv1");
66     _conv2->name("conv2");
67     _relu->name("relu");
68     _conv1_f->name("conv1f");
69     _conv1_b->name("conv1b");
70     _conv2_f->name("conv2f");
71     _conv2_b->name("conv2b");
72   }
73
74 public:
75   luci::CircleRelu *relu() { return _relu; }
76   luci::CircleConv2D *conv1() { return _conv1; }
77   luci::CircleConv2D *conv2() { return _conv2; }
78
79 protected:
80   luci::CircleConv2D *_conv1 = nullptr;
81   luci::CircleConv2D *_conv2 = nullptr;
82   luci::CircleRelu *_relu = nullptr;
83   luci::CircleConst *_conv1_f = nullptr;
84   luci::CircleConst *_conv1_b = nullptr;
85   luci::CircleConst *_conv2_f = nullptr;
86   luci::CircleConst *_conv2_b = nullptr;
87 };
88
89 class FuseActTestGraph : public TestIOGraph, public ConvReluConvGraphlet
90 {
91 public:
92   FuseActTestGraph() = default;
93
94   void init(void)
95   {
96     TestIOGraph::init({1}, {1});
97     ConvReluConvGraphlet::init(g());
98
99     _conv1->input(input());
100     _conv1->filter(_conv1_f);
101     _conv1->bias(_conv1_b);
102
103     _relu->features(_conv1);
104
105     _conv2->input(_relu);
106     _conv2->filter(_conv2_f);
107     _conv2->bias(_conv2_b);
108
109     output()->from(_conv2);
110   }
111 };
112
113 class ConvHasMultiSuccGraph : public TestIOGraph, public ConvReluConvGraphlet
114 {
115 public:
116   ConvHasMultiSuccGraph() = default;
117
118   void init(void)
119   {
120     TestIOGraph::init({1}, {1});
121     ConvReluConvGraphlet::init(g());
122
123     _conv1->input(input());
124     _conv1->filter(_conv1_f);
125     _conv1->bias(_conv1_b);
126
127     _relu->features(_conv1);
128
129     _conv2->input(_conv1);
130     _conv2->filter(_conv2_f);
131     _conv2->bias(_conv2_b);
132
133     output()->from(_relu); // We need to check from relu
134   }
135 };
136
137 // TODO use ::testing::Test
138
139 } // namespace
140
141 TEST(FuseActivationFunctionPassTest, name)
142 {
143   luci::FuseActivationFunctionPass pass;
144   auto const name = pass.name();
145   ASSERT_NE(nullptr, name);
146 }
147
148 TEST(FusePreActivationBatchNorm, fuse_activation_function)
149 {
150   FuseActTestGraph g;
151   luci::FuseActivationFunctionPass pass;
152
153   g.init();
154
155   EXPECT_TRUE(pass.run(g.g()));
156   EXPECT_EQ(g.conv1(), g.conv2()->input());
157 }
158
159 TEST(FusePreActivationBatchNorm, fuse_activation_function_dup_relu)
160 {
161   FuseActTestGraph g;
162   luci::FuseActivationFunctionPass pass;
163
164   g.init();
165   g.conv1()->fusedActivationFunction(luci::FusedActFunc::RELU);
166
167   EXPECT_TRUE(pass.run(g.g()));
168   EXPECT_EQ(g.conv1(), g.conv2()->input());
169 }
170
171 TEST(FusePreActivationBatchNorm, fuse_activation_function_mulsucc_NEG)
172 {
173   ConvHasMultiSuccGraph g;
174   luci::FuseActivationFunctionPass pass;
175
176   g.init();
177
178   // Relu input Conv2D has multiple successors
179   EXPECT_FALSE(pass.run(g.g()));
180 }
181
182 TEST(FusePreActivationBatchNorm, fuse_activation_function_tanh_NEG)
183 {
184   FuseActTestGraph g;
185   luci::FuseActivationFunctionPass pass;
186
187   g.init();
188   g.conv1()->fusedActivationFunction(luci::FusedActFunc::TANH);
189
190   // Relu input Conv2D already has activation function
191   EXPECT_FALSE(pass.run(g.g()));
192 }