fb5b6adc02287dd9e088175b78caa135b9b33a23
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / FoldDequantizePass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/FoldDequantizePass.h"
18 #include "PassTestGraphs.h"
19
20 #include <gtest/gtest.h>
21
22 namespace
23 {
24
25 template <loco::DataType DT>
26 class FoldDequantizeTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
27 {
28 public:
29   FoldDequantizeTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {}
30
31   virtual void SetUp() { init(); }
32
33   loco::Node *createFoldedPattern() override
34   {
35     _dequantize = _g.nodes()->create<luci::CircleDequantize>();
36     _input = _g.nodes()->create<luci::CircleConst>();
37
38     _dequantize->dtype(loco::DataType::FLOAT32);
39     _input->dtype(DT);
40
41     _input->shape({2, 2, 2});
42
43     _input->size<DT>(8);
44     _input->at<DT>(0) = 0;
45     _input->at<DT>(1) = 1;
46     _input->at<DT>(2) = 2;
47     _input->at<DT>(3) = 3;
48     _input->at<DT>(4) = 4;
49     _input->at<DT>(5) = 5;
50     _input->at<DT>(6) = 6;
51     _input->at<DT>(7) = 7;
52
53     auto qparam = std::make_unique<luci::CircleQuantParam>();
54     qparam->quantized_dimension = 1;
55     qparam->scale.push_back(5.0);
56     qparam->scale.push_back(10.0);
57     qparam->zerop.push_back(1);
58     qparam->zerop.push_back(2);
59     _input->quantparam(std::move(qparam));
60
61     _dequantize->input(_input);
62
63     _dequantize->name("dequantize");
64     _input->name("input");
65
66     return _dequantize;
67   }
68
69   void createScalarPattern()
70   {
71     _input->rank(0);
72     _input->size<DT>(1);
73     _input->at<DT>(0) = 1;
74
75     auto qparam = std::make_unique<luci::CircleQuantParam>();
76     qparam->quantized_dimension = 0;
77     qparam->scale.push_back(1.0);
78     qparam->zerop.push_back(0);
79     _input->quantparam(std::move(qparam));
80   }
81
82   void createNotFoldablePattern() { _input->quantparam(nullptr); }
83
84 protected:
85   luci::CircleDequantize *_dequantize = nullptr;
86   luci::CircleConst *_input = nullptr;
87 };
88
89 class S8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S8>
90 {
91 };
92
93 class S16FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S16>
94 {
95 };
96
97 class S32FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S32>
98 {
99 };
100
101 class S64FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S64>
102 {
103 };
104
105 class U8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::U8>
106 {
107 };
108
109 class F16FoldDequantizeTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
110 {
111 public:
112   F16FoldDequantizeTest() : ConstantFoldingTestGraph({2, 2}, loco::DataType::FLOAT16) {}
113
114   virtual void SetUp() { init(); }
115
116   loco::Node *createFoldedPattern() override
117   {
118     const auto DT = loco::DataType::FLOAT16;
119     _dequantize = _g.nodes()->create<luci::CircleDequantize>();
120     _f16const = _g.nodes()->create<luci::CircleConst>();
121
122     _dequantize->dtype(loco::DataType::FLOAT32);
123     _f16const->dtype(DT);
124
125     _f16const->shape({2, 2});
126
127     _f16const->size<loco::DataType::FLOAT16>(4);
128     _f16const->at<DT>(0) = 49408; // -2.5f
129     _f16const->at<DT>(1) = 47104; // -0.5f
130     _f16const->at<DT>(2) = 0;     //  0.0f
131     _f16const->at<DT>(3) = 15872; //  1.5f
132     // NOTE how to get uint16_t value of float16 ?
133     // Use compiler/souschef/src/Gaussian.cpp GaussianFloat16DataChef::generate()
134     //   uint16_t value = fp16_ieee_from_fp32_value(-2.5);
135     //   printf("-2.5 = %u\r\n", value);
136
137     _dequantize->input(_f16const);
138
139     _dequantize->name("dequantize");
140     _f16const->name("input");
141
142     _output->from(_dequantize);
143
144     return _dequantize;
145   }
146
147   void createNotFoldablePattern() { _dequantize->input(_input); }
148
149 protected:
150   luci::CircleConst *getFoldedPattern() override
151   {
152     return dynamic_cast<luci::CircleConst *>(_output->from());
153   }
154
155   void init() override { createFoldedPattern(); }
156
157 protected:
158   luci::CircleDequantize *_dequantize = nullptr;
159   luci::CircleConst *_f16const = nullptr;
160 };
161
162 } // namespace
163
164 TEST(FoldDequantizePassTest, name)
165 {
166   luci::FoldDequantizePass pass;
167   auto const name = pass.name();
168   ASSERT_NE(nullptr, name);
169 }
170
171 TEST_F(U8FoldDequantizeTest, fold_dequant_basic)
172 {
173   luci::FoldDequantizePass pass;
174   while (pass.run(graph()))
175     ;
176
177   auto folded_const = getFoldedPattern();
178   EXPECT_NE(nullptr, folded_const);
179
180   // Chec type, shape, values of folded const
181   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
182   EXPECT_EQ(3, folded_const->rank());
183   EXPECT_EQ(2, folded_const->dim(0).value());
184   EXPECT_EQ(2, folded_const->dim(1).value());
185   EXPECT_EQ(2, folded_const->dim(2).value());
186   EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
187   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
188   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
189   EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
190   EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
191   EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
192   EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
193   EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
194 }
195
196 TEST_F(U8FoldDequantizeTest, fold_dequant_basic_NEG)
197 {
198   createNotFoldablePattern();
199
200   luci::FoldDequantizePass pass;
201   while (pass.run(graph()))
202     ;
203
204   auto folded_const = getFoldedPattern();
205   EXPECT_EQ(nullptr, folded_const);
206 }
207
208 TEST_F(S8FoldDequantizeTest, fold_dequant_basic)
209 {
210   luci::FoldDequantizePass pass;
211   while (pass.run(graph()))
212     ;
213
214   auto folded_const = getFoldedPattern();
215   EXPECT_NE(nullptr, folded_const);
216
217   // Chec type, shape, values of folded const
218   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
219   EXPECT_EQ(3, folded_const->rank());
220   EXPECT_EQ(2, folded_const->dim(0).value());
221   EXPECT_EQ(2, folded_const->dim(1).value());
222   EXPECT_EQ(2, folded_const->dim(2).value());
223   EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
224   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
225   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
226   EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
227   EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
228   EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
229   EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
230   EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
231 }
232
233 TEST_F(S8FoldDequantizeTest, fold_dequant_basic_NEG)
234 {
235   createNotFoldablePattern();
236
237   luci::FoldDequantizePass pass;
238   while (pass.run(graph()))
239     ;
240
241   auto folded_const = getFoldedPattern();
242   EXPECT_EQ(nullptr, folded_const);
243 }
244
245 TEST_F(S16FoldDequantizeTest, fold_dequant_basic)
246 {
247   luci::FoldDequantizePass pass;
248   while (pass.run(graph()))
249     ;
250
251   auto folded_const = getFoldedPattern();
252   EXPECT_NE(nullptr, folded_const);
253
254   // Chec type, shape, values of folded const
255   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
256   EXPECT_EQ(3, folded_const->rank());
257   EXPECT_EQ(2, folded_const->dim(0).value());
258   EXPECT_EQ(2, folded_const->dim(1).value());
259   EXPECT_EQ(2, folded_const->dim(2).value());
260   EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
261   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
262   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
263   EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
264   EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
265   EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
266   EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
267   EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
268 }
269
270 TEST_F(S16FoldDequantizeTest, fold_dequant_basic_NEG)
271 {
272   createNotFoldablePattern();
273
274   luci::FoldDequantizePass pass;
275   while (pass.run(graph()))
276     ;
277
278   auto folded_const = getFoldedPattern();
279   EXPECT_EQ(nullptr, folded_const);
280 }
281
282 TEST_F(S32FoldDequantizeTest, fold_dequant_basic)
283 {
284   luci::FoldDequantizePass pass;
285   while (pass.run(graph()))
286     ;
287
288   auto folded_const = getFoldedPattern();
289   EXPECT_NE(nullptr, folded_const);
290
291   // Chec type, shape, values of folded const
292   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
293   EXPECT_EQ(3, folded_const->rank());
294   EXPECT_EQ(2, folded_const->dim(0).value());
295   EXPECT_EQ(2, folded_const->dim(1).value());
296   EXPECT_EQ(2, folded_const->dim(2).value());
297   EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
298   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
299   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
300   EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
301   EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
302   EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
303   EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
304   EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
305 }
306
307 TEST_F(S32FoldDequantizeTest, fold_dequant_basic_NEG)
308 {
309   createNotFoldablePattern();
310
311   luci::FoldDequantizePass pass;
312   while (pass.run(graph()))
313     ;
314
315   auto folded_const = getFoldedPattern();
316   EXPECT_EQ(nullptr, folded_const);
317 }
318
319 TEST_F(S64FoldDequantizeTest, fold_dequant_basic)
320 {
321   luci::FoldDequantizePass pass;
322   while (pass.run(graph()))
323     ;
324
325   auto folded_const = getFoldedPattern();
326   EXPECT_NE(nullptr, folded_const);
327
328   // Chec type, shape, values of folded const
329   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
330   EXPECT_EQ(3, folded_const->rank());
331   EXPECT_EQ(2, folded_const->dim(0).value());
332   EXPECT_EQ(2, folded_const->dim(1).value());
333   EXPECT_EQ(2, folded_const->dim(2).value());
334   EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
335   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
336   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
337   EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
338   EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
339   EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
340   EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
341   EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
342 }
343
344 TEST_F(S64FoldDequantizeTest, fold_dequant_basic_NEG)
345 {
346   createNotFoldablePattern();
347
348   luci::FoldDequantizePass pass;
349   while (pass.run(graph()))
350     ;
351
352   auto folded_const = getFoldedPattern();
353   EXPECT_EQ(nullptr, folded_const);
354 }
355
356 TEST_F(U8FoldDequantizeTest, fold_dequant_scalar)
357 {
358   createScalarPattern();
359
360   luci::FoldDequantizePass pass;
361   while (pass.run(graph()))
362     ;
363
364   auto folded_const = getFoldedPattern();
365   EXPECT_NE(nullptr, folded_const);
366
367   // Check type, shape, values of folded const
368   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
369   EXPECT_EQ(0, folded_const->rank());
370   EXPECT_EQ(1.0, folded_const->at<loco::DataType::FLOAT32>(0));
371 }
372
373 TEST_F(F16FoldDequantizeTest, fold_dequant_basic)
374 {
375   luci::FoldDequantizePass pass;
376   while (pass.run(graph()))
377     ;
378
379   auto folded_const = getFoldedPattern();
380   EXPECT_NE(nullptr, folded_const);
381
382   // Chec type, shape, values of folded const
383   EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
384   EXPECT_EQ(2, folded_const->rank());
385   EXPECT_EQ(2, folded_const->dim(0).value());
386   EXPECT_EQ(2, folded_const->dim(1).value());
387   EXPECT_EQ(-2.5, folded_const->at<loco::DataType::FLOAT32>(0));
388   EXPECT_EQ(-0.5, folded_const->at<loco::DataType::FLOAT32>(1));
389   EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
390   EXPECT_EQ(1.5, folded_const->at<loco::DataType::FLOAT32>(3));
391 }
392
393 TEST_F(F16FoldDequantizeTest, fold_dequant_basic_NEG)
394 {
395   createNotFoldablePattern();
396
397   luci::FoldDequantizePass pass;
398   while (pass.run(graph()))
399     ;
400
401   auto folded_const = getFoldedPattern();
402   EXPECT_EQ(nullptr, folded_const);
403 }