788353cd8ee9e32879120d99dc27e3862029ff18
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / QuantizePreCheckerPass.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/QuantizePreCheckerPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 class SimpleConv2DGraph
24 {
25 public:
26   SimpleConv2DGraph(bool make_valid)
27   {
28     conv2d_node = g.nodes()->create<luci::CircleConv2D>();
29     input_1 = g.nodes()->create<luci::CircleInput>();
30     filter = g.nodes()->create<luci::CircleConst>();
31
32     conv2d_node->input(input_1);
33     conv2d_node->filter(filter);
34
35     if (make_valid)
36     {
37       bias = g.nodes()->create<luci::CircleConst>();
38       conv2d_node->bias(bias);
39     }
40     else
41     {
42       input_2 = g.nodes()->create<luci::CircleInput>();
43       conv2d_node->bias(input_2);
44     }
45
46     output = g.nodes()->create<luci::CircleOutput>();
47
48     auto graph_output = g.outputs()->create();
49     output->index(graph_output->index());
50
51     output->from(conv2d_node);
52   }
53
54 public:
55   loco::Graph g;
56
57 private:
58   luci::CircleConv2D *conv2d_node = nullptr;
59   luci::CircleInput *input_1 = nullptr;
60   luci::CircleInput *input_2 = nullptr;
61   luci::CircleConst *filter = nullptr;
62   luci::CircleConst *bias = nullptr;
63   luci::CircleOutput *output = nullptr;
64 };
65
66 class SimpleDepthConv2DGraph
67 {
68 public:
69   SimpleDepthConv2DGraph(bool make_valid)
70   {
71     depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>();
72     input_1 = g.nodes()->create<luci::CircleInput>();
73     filter = g.nodes()->create<luci::CircleConst>();
74
75     depth_conv2d_node->input(input_1);
76     depth_conv2d_node->filter(filter);
77
78     if (make_valid)
79     {
80       bias = g.nodes()->create<luci::CircleConst>();
81       depth_conv2d_node->bias(bias);
82     }
83     else
84     {
85       input_2 = g.nodes()->create<luci::CircleInput>();
86       depth_conv2d_node->bias(input_2);
87     }
88
89     output = g.nodes()->create<luci::CircleOutput>();
90
91     auto graph_output = g.outputs()->create();
92     output->index(graph_output->index());
93
94     output->from(depth_conv2d_node);
95   }
96
97 public:
98   loco::Graph g;
99
100 private:
101   luci::CircleDepthwiseConv2D *depth_conv2d_node = nullptr;
102   luci::CircleInput *input_1 = nullptr;
103   luci::CircleInput *input_2 = nullptr;
104   luci::CircleConst *filter = nullptr;
105   luci::CircleConst *bias = nullptr;
106   luci::CircleOutput *output = nullptr;
107 };
108
109 class SimpleFCGraph
110 {
111 public:
112   SimpleFCGraph(bool make_valid)
113   {
114     fc_node = g.nodes()->create<luci::CircleFullyConnected>();
115     input_1 = g.nodes()->create<luci::CircleInput>();
116     weights = g.nodes()->create<luci::CircleConst>();
117
118     fc_node->input(input_1);
119     fc_node->weights(weights);
120
121     if (make_valid)
122     {
123       bias = g.nodes()->create<luci::CircleConst>();
124       fc_node->bias(bias);
125     }
126     else
127     {
128       input_2 = g.nodes()->create<luci::CircleInput>();
129       fc_node->bias(input_2);
130     }
131
132     output = g.nodes()->create<luci::CircleOutput>();
133
134     auto graph_output = g.outputs()->create();
135     output->index(graph_output->index());
136
137     output->from(fc_node);
138   }
139
140 public:
141   loco::Graph g;
142
143 private:
144   luci::CircleFullyConnected *fc_node = nullptr;
145   luci::CircleInput *input_1 = nullptr;
146   luci::CircleInput *input_2 = nullptr;
147   luci::CircleConst *weights = nullptr;
148   luci::CircleConst *bias = nullptr;
149   luci::CircleOutput *output = nullptr;
150 };
151
152 class SimpleInstanceNormGraph
153 {
154 public:
155   SimpleInstanceNormGraph(bool make_valid)
156   {
157     instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>();
158     input_1 = g.nodes()->create<luci::CircleInput>();
159     gamma = g.nodes()->create<luci::CircleConst>();
160
161     instance_norm_node->input(input_1);
162     instance_norm_node->gamma(gamma);
163
164     if (make_valid)
165     {
166       beta = g.nodes()->create<luci::CircleConst>();
167       instance_norm_node->beta(beta);
168     }
169     else
170     {
171       input_2 = g.nodes()->create<luci::CircleInput>();
172       instance_norm_node->beta(input_2);
173     }
174
175     output = g.nodes()->create<luci::CircleOutput>();
176
177     auto graph_output = g.outputs()->create();
178     output->index(graph_output->index());
179
180     output->from(instance_norm_node);
181   }
182
183 public:
184   loco::Graph g;
185
186 private:
187   luci::CircleInstanceNorm *instance_norm_node = nullptr;
188   luci::CircleInput *input_1 = nullptr;
189   luci::CircleInput *input_2 = nullptr;
190   luci::CircleConst *gamma = nullptr;
191   luci::CircleConst *beta = nullptr;
192   luci::CircleOutput *output = nullptr;
193 };
194
195 class SimpleTransposeConvGraph
196 {
197 public:
198   SimpleTransposeConvGraph(bool make_valid)
199   {
200     transpose_conv = g.nodes()->create<luci::CircleTransposeConv>();
201     input_1 = g.nodes()->create<luci::CircleInput>();
202
203     input_sizes = g.nodes()->create<luci::CircleConst>();
204     filter = g.nodes()->create<luci::CircleConst>();
205
206     transpose_conv->outBackprop(input_1);
207     transpose_conv->filter(filter);
208     transpose_conv->inputSizes(input_sizes);
209
210     if (make_valid)
211     {
212       bias = g.nodes()->create<luci::CircleConst>();
213       transpose_conv->bias(bias);
214     }
215     else
216     {
217       input_2 = g.nodes()->create<luci::CircleInput>();
218       transpose_conv->bias(input_2);
219     }
220
221     output = g.nodes()->create<luci::CircleOutput>();
222
223     auto graph_output = g.outputs()->create();
224     output->index(graph_output->index());
225
226     output->from(transpose_conv);
227   }
228
229 public:
230   loco::Graph g;
231
232 private:
233   luci::CircleTransposeConv *transpose_conv = nullptr;
234   luci::CircleInput *input_1 = nullptr;
235   luci::CircleInput *input_2 = nullptr;
236   luci::CircleConst *input_sizes = nullptr;
237   luci::CircleConst *filter = nullptr;
238   luci::CircleConst *bias = nullptr;
239   luci::CircleOutput *output = nullptr;
240 };
241
242 class SimplePReluGraph
243 {
244 public:
245   SimplePReluGraph(bool make_valid)
246   {
247     prelu = g.nodes()->create<luci::CirclePRelu>();
248     input_1 = g.nodes()->create<luci::CircleInput>();
249
250     prelu->input(input_1);
251
252     if (make_valid)
253     {
254       alpha = g.nodes()->create<luci::CircleConst>();
255       prelu->alpha(alpha);
256     }
257     else
258     {
259       input_2 = g.nodes()->create<luci::CircleInput>();
260       prelu->alpha(input_2);
261     }
262
263     output = g.nodes()->create<luci::CircleOutput>();
264
265     auto graph_output = g.outputs()->create();
266     output->index(graph_output->index());
267
268     output->from(prelu);
269   }
270
271 public:
272   loco::Graph g;
273
274 private:
275   luci::CirclePRelu *prelu = nullptr;
276   luci::CircleInput *input_1 = nullptr;
277   luci::CircleInput *input_2 = nullptr;
278   luci::CircleConst *alpha = nullptr;
279   luci::CircleOutput *output = nullptr;
280 };
281
282 TEST(QuantizePreCheckerPassTest, name)
283 {
284   luci::QuantizePreCheckerPass pass{};
285   auto const name = pass.name();
286   ASSERT_NE(nullptr, name);
287 }
288
289 // Test Conv2d
290 TEST(QuantizePreCheckerPassTest, conv2d)
291 {
292   SimpleConv2DGraph valid_graph(true);
293
294   luci::QuantizePreCheckerPass checker{};
295
296   EXPECT_NO_THROW(checker.run(&valid_graph.g));
297 }
298
299 TEST(QuantizePreCheckerPassTest, conv2d_NEG)
300 {
301   SimpleConv2DGraph invalid_graph(false);
302
303   luci::QuantizePreCheckerPass checker{};
304
305   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
306 }
307
308 // Test DepthwiseConv2d
309 TEST(QuantizePreCheckerPassTest, depthwise_conv2d)
310 {
311   SimpleDepthConv2DGraph valid_graph(true);
312
313   luci::QuantizePreCheckerPass checker{};
314
315   EXPECT_NO_THROW(checker.run(&valid_graph.g));
316 }
317
318 TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG)
319 {
320   SimpleDepthConv2DGraph invalid_graph(false);
321
322   luci::QuantizePreCheckerPass checker{};
323
324   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
325 }
326
327 // Test FullyConnected
328 TEST(QuantizePreCheckerPassTest, fully_connected)
329 {
330   SimpleFCGraph valid_graph(true);
331
332   luci::QuantizePreCheckerPass checker{};
333
334   EXPECT_NO_THROW(checker.run(&valid_graph.g));
335 }
336
337 TEST(QuantizePreCheckerPassTest, fully_connected_NEG)
338 {
339   SimpleFCGraph invalid_graph(false);
340
341   luci::QuantizePreCheckerPass checker{};
342
343   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
344 }
345
346 // Test InstanceNorm
347 TEST(QuantizePreCheckerPassTest, instance_norm)
348 {
349   SimpleInstanceNormGraph valid_graph(true);
350
351   luci::QuantizePreCheckerPass checker{};
352
353   EXPECT_NO_THROW(checker.run(&valid_graph.g));
354 }
355
356 TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
357 {
358   SimpleInstanceNormGraph invalid_graph(false);
359
360   luci::QuantizePreCheckerPass checker{};
361
362   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
363 }
364
365 // Test TransposeConv
366 TEST(QuantizePreCheckerPassTest, transpose_conv)
367 {
368   SimpleTransposeConvGraph valid_graph(true);
369
370   luci::QuantizePreCheckerPass checker{};
371
372   EXPECT_NO_THROW(checker.run(&valid_graph.g));
373 }
374
375 TEST(QuantizePreCheckerPassTest, transpose_conv_NEG)
376 {
377   SimpleTransposeConvGraph invalid_graph(false);
378
379   luci::QuantizePreCheckerPass checker{};
380
381   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
382 }
383
384 // Test PRelu
385 TEST(QuantizePreCheckerPassTest, prelu)
386 {
387   SimplePReluGraph valid_graph(true);
388
389   luci::QuantizePreCheckerPass checker{};
390
391   EXPECT_NO_THROW(checker.run(&valid_graph.g));
392 }
393
394 TEST(QuantizePreCheckerPassTest, prelu_NEG)
395 {
396   SimplePReluGraph invalid_graph(false);
397
398   luci::QuantizePreCheckerPass checker{};
399
400   EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
401 }