/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/QuantizePreCheckerPass.h"

#include <luci/IR/CircleNodes.h>

#include <loco.h>

#include <gtest/gtest.h>
23 class SimpleConv2DGraph
26 SimpleConv2DGraph(bool make_valid)
28 conv2d_node = g.nodes()->create<luci::CircleConv2D>();
29 input_1 = g.nodes()->create<luci::CircleInput>();
30 filter = g.nodes()->create<luci::CircleConst>();
32 conv2d_node->input(input_1);
33 conv2d_node->filter(filter);
37 bias = g.nodes()->create<luci::CircleConst>();
38 conv2d_node->bias(bias);
42 input_2 = g.nodes()->create<luci::CircleInput>();
43 conv2d_node->bias(input_2);
46 output = g.nodes()->create<luci::CircleOutput>();
48 auto graph_output = g.outputs()->create();
49 output->index(graph_output->index());
51 output->from(conv2d_node);
58 luci::CircleConv2D *conv2d_node = nullptr;
59 luci::CircleInput *input_1 = nullptr;
60 luci::CircleInput *input_2 = nullptr;
61 luci::CircleConst *filter = nullptr;
62 luci::CircleConst *bias = nullptr;
63 luci::CircleOutput *output = nullptr;
66 class SimpleDepthConv2DGraph
69 SimpleDepthConv2DGraph(bool make_valid)
71 depth_conv2d_node = g.nodes()->create<luci::CircleDepthwiseConv2D>();
72 input_1 = g.nodes()->create<luci::CircleInput>();
73 filter = g.nodes()->create<luci::CircleConst>();
75 depth_conv2d_node->input(input_1);
76 depth_conv2d_node->filter(filter);
80 bias = g.nodes()->create<luci::CircleConst>();
81 depth_conv2d_node->bias(bias);
85 input_2 = g.nodes()->create<luci::CircleInput>();
86 depth_conv2d_node->bias(input_2);
89 output = g.nodes()->create<luci::CircleOutput>();
91 auto graph_output = g.outputs()->create();
92 output->index(graph_output->index());
94 output->from(depth_conv2d_node);
101 luci::CircleDepthwiseConv2D *depth_conv2d_node = nullptr;
102 luci::CircleInput *input_1 = nullptr;
103 luci::CircleInput *input_2 = nullptr;
104 luci::CircleConst *filter = nullptr;
105 luci::CircleConst *bias = nullptr;
106 luci::CircleOutput *output = nullptr;
112 SimpleFCGraph(bool make_valid)
114 fc_node = g.nodes()->create<luci::CircleFullyConnected>();
115 input_1 = g.nodes()->create<luci::CircleInput>();
116 weights = g.nodes()->create<luci::CircleConst>();
118 fc_node->input(input_1);
119 fc_node->weights(weights);
123 bias = g.nodes()->create<luci::CircleConst>();
128 input_2 = g.nodes()->create<luci::CircleInput>();
129 fc_node->bias(input_2);
132 output = g.nodes()->create<luci::CircleOutput>();
134 auto graph_output = g.outputs()->create();
135 output->index(graph_output->index());
137 output->from(fc_node);
144 luci::CircleFullyConnected *fc_node = nullptr;
145 luci::CircleInput *input_1 = nullptr;
146 luci::CircleInput *input_2 = nullptr;
147 luci::CircleConst *weights = nullptr;
148 luci::CircleConst *bias = nullptr;
149 luci::CircleOutput *output = nullptr;
152 class SimpleInstanceNormGraph
155 SimpleInstanceNormGraph(bool make_valid)
157 instance_norm_node = g.nodes()->create<luci::CircleInstanceNorm>();
158 input_1 = g.nodes()->create<luci::CircleInput>();
159 gamma = g.nodes()->create<luci::CircleConst>();
161 instance_norm_node->input(input_1);
162 instance_norm_node->gamma(gamma);
166 beta = g.nodes()->create<luci::CircleConst>();
167 instance_norm_node->beta(beta);
171 input_2 = g.nodes()->create<luci::CircleInput>();
172 instance_norm_node->beta(input_2);
175 output = g.nodes()->create<luci::CircleOutput>();
177 auto graph_output = g.outputs()->create();
178 output->index(graph_output->index());
180 output->from(instance_norm_node);
187 luci::CircleInstanceNorm *instance_norm_node = nullptr;
188 luci::CircleInput *input_1 = nullptr;
189 luci::CircleInput *input_2 = nullptr;
190 luci::CircleConst *gamma = nullptr;
191 luci::CircleConst *beta = nullptr;
192 luci::CircleOutput *output = nullptr;
195 class SimpleTransposeConvGraph
198 SimpleTransposeConvGraph(bool make_valid)
200 transpose_conv = g.nodes()->create<luci::CircleTransposeConv>();
201 input_1 = g.nodes()->create<luci::CircleInput>();
203 input_sizes = g.nodes()->create<luci::CircleConst>();
204 filter = g.nodes()->create<luci::CircleConst>();
206 transpose_conv->outBackprop(input_1);
207 transpose_conv->filter(filter);
208 transpose_conv->inputSizes(input_sizes);
212 bias = g.nodes()->create<luci::CircleConst>();
213 transpose_conv->bias(bias);
217 input_2 = g.nodes()->create<luci::CircleInput>();
218 transpose_conv->bias(input_2);
221 output = g.nodes()->create<luci::CircleOutput>();
223 auto graph_output = g.outputs()->create();
224 output->index(graph_output->index());
226 output->from(transpose_conv);
233 luci::CircleTransposeConv *transpose_conv = nullptr;
234 luci::CircleInput *input_1 = nullptr;
235 luci::CircleInput *input_2 = nullptr;
236 luci::CircleConst *input_sizes = nullptr;
237 luci::CircleConst *filter = nullptr;
238 luci::CircleConst *bias = nullptr;
239 luci::CircleOutput *output = nullptr;
242 class SimplePReluGraph
245 SimplePReluGraph(bool make_valid)
247 prelu = g.nodes()->create<luci::CirclePRelu>();
248 input_1 = g.nodes()->create<luci::CircleInput>();
250 prelu->input(input_1);
254 alpha = g.nodes()->create<luci::CircleConst>();
259 input_2 = g.nodes()->create<luci::CircleInput>();
260 prelu->alpha(input_2);
263 output = g.nodes()->create<luci::CircleOutput>();
265 auto graph_output = g.outputs()->create();
266 output->index(graph_output->index());
275 luci::CirclePRelu *prelu = nullptr;
276 luci::CircleInput *input_1 = nullptr;
277 luci::CircleInput *input_2 = nullptr;
278 luci::CircleConst *alpha = nullptr;
279 luci::CircleOutput *output = nullptr;
282 TEST(QuantizePreCheckerPassTest, name)
284 luci::QuantizePreCheckerPass pass{};
285 auto const name = pass.name();
286 ASSERT_NE(nullptr, name);
290 TEST(QuantizePreCheckerPassTest, conv2d)
292 SimpleConv2DGraph valid_graph(true);
294 luci::QuantizePreCheckerPass checker{};
296 EXPECT_NO_THROW(checker.run(&valid_graph.g));
299 TEST(QuantizePreCheckerPassTest, conv2d_NEG)
301 SimpleConv2DGraph invalid_graph(false);
303 luci::QuantizePreCheckerPass checker{};
305 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
308 // Test DepthwiseConv2d
309 TEST(QuantizePreCheckerPassTest, depthwise_conv2d)
311 SimpleDepthConv2DGraph valid_graph(true);
313 luci::QuantizePreCheckerPass checker{};
315 EXPECT_NO_THROW(checker.run(&valid_graph.g));
318 TEST(QuantizePreCheckerPassTest, depthwise_conv2d_NEG)
320 SimpleDepthConv2DGraph invalid_graph(false);
322 luci::QuantizePreCheckerPass checker{};
324 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
327 // Test FullyConnected
328 TEST(QuantizePreCheckerPassTest, fully_connected)
330 SimpleFCGraph valid_graph(true);
332 luci::QuantizePreCheckerPass checker{};
334 EXPECT_NO_THROW(checker.run(&valid_graph.g));
337 TEST(QuantizePreCheckerPassTest, fully_connected_NEG)
339 SimpleFCGraph invalid_graph(false);
341 luci::QuantizePreCheckerPass checker{};
343 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
347 TEST(QuantizePreCheckerPassTest, instance_norm)
349 SimpleInstanceNormGraph valid_graph(true);
351 luci::QuantizePreCheckerPass checker{};
353 EXPECT_NO_THROW(checker.run(&valid_graph.g));
356 TEST(QuantizePreCheckerPassTest, instance_norm_NEG)
358 SimpleInstanceNormGraph invalid_graph(false);
360 luci::QuantizePreCheckerPass checker{};
362 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
365 // Test TransposeConv
366 TEST(QuantizePreCheckerPassTest, transpose_conv)
368 SimpleTransposeConvGraph valid_graph(true);
370 luci::QuantizePreCheckerPass checker{};
372 EXPECT_NO_THROW(checker.run(&valid_graph.g));
375 TEST(QuantizePreCheckerPassTest, transpose_conv_NEG)
377 SimpleTransposeConvGraph invalid_graph(false);
379 luci::QuantizePreCheckerPass checker{};
381 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));
385 TEST(QuantizePreCheckerPassTest, prelu)
387 SimplePReluGraph valid_graph(true);
389 luci::QuantizePreCheckerPass checker{};
391 EXPECT_NO_THROW(checker.run(&valid_graph.g));
394 TEST(QuantizePreCheckerPassTest, prelu_NEG)
396 SimplePReluGraph invalid_graph(false);
398 luci::QuantizePreCheckerPass checker{};
400 EXPECT_ANY_THROW(checker.run(&invalid_graph.g));