Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FuseActivationFunctionPass.test.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FuseActivationFunctionPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <luci/test/TestIOGraph.h>
22
23 #include <gtest/gtest.h>
24
25 namespace
26 {
27
28 using namespace luci::test;
29
30 /**
31  *  Simple graph for test
32  *
33  *  BEFORE
34  *
35  *         [Conv1]
36  *           |
37  *     [Activation func]
38  *           |
39  *         [Conv2]
40  *
41  *  AFTER
42  *
43  *   [Conv1 + Activation func]
44  *           |
45  *         [Conv2]
46  *
47  */
48 class ConvReluConvGraphlet
49 {
50 public:
51   ConvReluConvGraphlet() = default;
52
53   void init(loco::Graph *g)
54   {
55     _conv1 = g->nodes()->create<luci::CircleConv2D>();
56     _conv2 = g->nodes()->create<luci::CircleConv2D>();
57     _relu = g->nodes()->create<luci::CircleRelu>();
58     _conv1_f = g->nodes()->create<luci::CircleConst>();
59     _conv1_b = g->nodes()->create<luci::CircleConst>();
60     _conv2_f = g->nodes()->create<luci::CircleConst>();
61     _conv2_b = g->nodes()->create<luci::CircleConst>();
62
63     _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
64
65     _conv1->name("conv1");
66     _conv2->name("conv2");
67     _relu->name("relu");
68     _conv1_f->name("conv1f");
69     _conv1_b->name("conv1b");
70     _conv2_f->name("conv2f");
71     _conv2_b->name("conv2b");
72   }
73
74 public:
75   luci::CircleRelu *relu() { return _relu; }
76   luci::CircleConv2D *conv1() { return _conv1; }
77   luci::CircleConv2D *conv2() { return _conv2; }
78
79 protected:
80   luci::CircleConv2D *_conv1 = nullptr;
81   luci::CircleConv2D *_conv2 = nullptr;
82   luci::CircleRelu *_relu = nullptr;
83   luci::CircleConst *_conv1_f = nullptr;
84   luci::CircleConst *_conv1_b = nullptr;
85   luci::CircleConst *_conv2_f = nullptr;
86   luci::CircleConst *_conv2_b = nullptr;
87 };
88
89 class ConvTanhConvGraphlet
90 {
91 public:
92   ConvTanhConvGraphlet() = default;
93
94   void init(loco::Graph *g)
95   {
96     _conv1 = g->nodes()->create<luci::CircleConv2D>();
97     _conv2 = g->nodes()->create<luci::CircleConv2D>();
98     _tanh = g->nodes()->create<luci::CircleTanh>();
99     _conv1_f = g->nodes()->create<luci::CircleConst>();
100     _conv1_b = g->nodes()->create<luci::CircleConst>();
101     _conv2_f = g->nodes()->create<luci::CircleConst>();
102     _conv2_b = g->nodes()->create<luci::CircleConst>();
103
104     _conv1->fusedActivationFunction(luci::FusedActFunc::NONE);
105
106     _conv1->name("conv1");
107     _conv2->name("conv2");
108     _tanh->name("tanh");
109     _conv1_f->name("conv1f");
110     _conv1_b->name("conv1b");
111     _conv2_f->name("conv2f");
112     _conv2_b->name("conv2b");
113   }
114
115 public:
116   luci::CircleTanh *tanh() { return _tanh; }
117   luci::CircleConv2D *conv1() { return _conv1; }
118   luci::CircleConv2D *conv2() { return _conv2; }
119
120 protected:
121   luci::CircleConv2D *_conv1 = nullptr;
122   luci::CircleConv2D *_conv2 = nullptr;
123   luci::CircleTanh *_tanh = nullptr;
124   luci::CircleConst *_conv1_f = nullptr;
125   luci::CircleConst *_conv1_b = nullptr;
126   luci::CircleConst *_conv2_f = nullptr;
127   luci::CircleConst *_conv2_b = nullptr;
128 };
129
130 class FuseActTestGraph : public TestIOGraph, public ConvReluConvGraphlet
131 {
132 public:
133   FuseActTestGraph() = default;
134
135   void init(void)
136   {
137     TestIOGraph::init({1}, {1});
138     ConvReluConvGraphlet::init(g());
139
140     _conv1->input(input());
141     _conv1->filter(_conv1_f);
142     _conv1->bias(_conv1_b);
143
144     _relu->features(_conv1);
145
146     _conv2->input(_relu);
147     _conv2->filter(_conv2_f);
148     _conv2->bias(_conv2_b);
149
150     output()->from(_conv2);
151   }
152 };
153
154 class FuseTanhActTestGraph : public TestIOGraph, public ConvTanhConvGraphlet
155 {
156 public:
157   FuseTanhActTestGraph() = default;
158
159   void init(void)
160   {
161     TestIOGraph::init({1}, {1});
162     ConvTanhConvGraphlet::init(g());
163
164     _conv1->input(input());
165     _conv1->filter(_conv1_f);
166     _conv1->bias(_conv1_b);
167
168     _tanh->x(_conv1);
169
170     _conv2->input(_tanh);
171     _conv2->filter(_conv2_f);
172     _conv2->bias(_conv2_b);
173
174     output()->from(_conv2);
175   }
176 };
177
178 class ConvHasMultiSuccGraph : public TestIOGraph, public ConvReluConvGraphlet
179 {
180 public:
181   ConvHasMultiSuccGraph() = default;
182
183   void init(void)
184   {
185     TestIOGraph::init({1}, {1});
186     ConvReluConvGraphlet::init(g());
187
188     _conv1->input(input());
189     _conv1->filter(_conv1_f);
190     _conv1->bias(_conv1_b);
191
192     _relu->features(_conv1);
193
194     _conv2->input(_conv1);
195     _conv2->filter(_conv2_f);
196     _conv2->bias(_conv2_b);
197
198     output()->from(_relu); // We need to check from relu
199   }
200 };
201
202 // TODO use ::testing::Test
203
204 } // namespace
205
206 TEST(FuseActivationFunctionPassTest, name)
207 {
208   luci::FuseActivationFunctionPass pass;
209   auto const name = pass.name();
210   ASSERT_NE(nullptr, name);
211 }
212
213 TEST(FusePreActivationBatchNorm, fuse_activation_function)
214 {
215   FuseActTestGraph g;
216   luci::FuseActivationFunctionPass pass;
217
218   g.init();
219
220   EXPECT_TRUE(pass.run(g.g()));
221   EXPECT_EQ(g.conv1(), g.conv2()->input());
222 }
223
224 TEST(FusePreActivationBatchNorm, fuse_activation_function_dup_relu)
225 {
226   FuseActTestGraph g;
227   luci::FuseActivationFunctionPass pass;
228
229   g.init();
230   g.conv1()->fusedActivationFunction(luci::FusedActFunc::RELU);
231
232   EXPECT_TRUE(pass.run(g.g()));
233   EXPECT_EQ(g.conv1(), g.conv2()->input());
234 }
235
236 TEST(FusePreActivationBatchNorm, fuse_activation_function_mulsucc_NEG)
237 {
238   ConvHasMultiSuccGraph g;
239   luci::FuseActivationFunctionPass pass;
240
241   g.init();
242
243   // Relu input Conv2D has multiple successors
244   EXPECT_FALSE(pass.run(g.g()));
245 }
246
247 TEST(FusePreActivationBatchNorm, fuse_activation_function_tanh_NEG)
248 {
249   FuseActTestGraph g;
250   luci::FuseActivationFunctionPass pass;
251
252   g.init();
253   g.conv1()->fusedActivationFunction(luci::FusedActFunc::TANH);
254
255   // Relu input Conv2D already has activation function
256   EXPECT_FALSE(pass.run(g.g()));
257 }
258
259 TEST(FusePreActivationBatchNorm, fuse_tanh_NEG)
260 {
261   FuseTanhActTestGraph g;
262   luci::FuseActivationFunctionPass pass;
263
264   g.init();
265
266   // Tanh should not be fused
267   // This can be changed when CONV+TANH is supported by luci-interpreter
268   EXPECT_FALSE(pass.run(g.g()));
269 }