2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/FoldDequantizePass.h"
18 #include "PassTestGraphs.h"
20 #include <gtest/gtest.h>
25 template <loco::DataType DT>
26 class FoldDequantizeTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
29 FoldDequantizeTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {}
31 virtual void SetUp() { init(); }
33 loco::Node *createFoldedPattern() override
35 _dequantize = _g.nodes()->template create<luci::CircleDequantize>();
36 _input = _g.nodes()->template create<luci::CircleConst>();
38 _dequantize->dtype(loco::DataType::FLOAT32);
41 _input->shape({2, 2, 2});
44 _input->at<DT>(0) = 0;
45 _input->at<DT>(1) = 1;
46 _input->at<DT>(2) = 2;
47 _input->at<DT>(3) = 3;
48 _input->at<DT>(4) = 4;
49 _input->at<DT>(5) = 5;
50 _input->at<DT>(6) = 6;
51 _input->at<DT>(7) = 7;
53 auto qparam = std::make_unique<luci::CircleQuantParam>();
54 qparam->quantized_dimension = 1;
55 qparam->scale.push_back(5.0);
56 qparam->scale.push_back(10.0);
57 qparam->zerop.push_back(1);
58 qparam->zerop.push_back(2);
59 _input->quantparam(std::move(qparam));
61 _dequantize->input(_input);
63 _dequantize->name("dequantize");
64 _input->name("input");
69 void createScalarPattern()
73 _input->at<DT>(0) = 1;
75 auto qparam = std::make_unique<luci::CircleQuantParam>();
76 qparam->quantized_dimension = 0;
77 qparam->scale.push_back(1.0);
78 qparam->zerop.push_back(0);
79 _input->quantparam(std::move(qparam));
82 void createNotFoldablePattern() { _input->quantparam(nullptr); }
85 luci::CircleDequantize *_dequantize = nullptr;
86 luci::CircleConst *_input = nullptr;
89 class S8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S8>
93 class S16FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S16>
97 class S32FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S32>
101 class S64FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S64>
105 class U8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::U8>
109 class F16FoldDequantizeTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
112 F16FoldDequantizeTest() : ConstantFoldingTestGraph({2, 2}, loco::DataType::FLOAT16) {}
114 virtual void SetUp() { init(); }
116 loco::Node *createFoldedPattern() override
118 const auto DT = loco::DataType::FLOAT16;
119 _dequantize = _g.nodes()->create<luci::CircleDequantize>();
120 _f16const = _g.nodes()->create<luci::CircleConst>();
122 _dequantize->dtype(loco::DataType::FLOAT32);
123 _f16const->dtype(DT);
125 _f16const->shape({2, 2});
127 _f16const->size<loco::DataType::FLOAT16>(4);
128 _f16const->at<DT>(0) = 49408; // -2.5f
129 _f16const->at<DT>(1) = 47104; // -0.5f
130 _f16const->at<DT>(2) = 0; // 0.0f
131 _f16const->at<DT>(3) = 15872; // 1.5f
132 // NOTE how to get uint16_t value of float16 ?
133 // Use compiler/souschef/src/Gaussian.cpp GaussianFloat16DataChef::generate()
134 // uint16_t value = fp16_ieee_from_fp32_value(-2.5);
135 // printf("-2.5 = %u\r\n", value);
137 _dequantize->input(_f16const);
139 _dequantize->name("dequantize");
140 _f16const->name("input");
142 _output->from(_dequantize);
147 void createNotFoldablePattern() { _dequantize->input(_input); }
150 luci::CircleConst *getFoldedPattern() override
152 return dynamic_cast<luci::CircleConst *>(_output->from());
155 void init() override { createFoldedPattern(); }
158 luci::CircleDequantize *_dequantize = nullptr;
159 luci::CircleConst *_f16const = nullptr;
164 TEST(FoldDequantizePassTest, name)
166 luci::FoldDequantizePass pass;
167 auto const name = pass.name();
168 ASSERT_NE(nullptr, name);
171 TEST_F(U8FoldDequantizeTest, fold_dequant_basic)
173 luci::FoldDequantizePass pass;
174 while (pass.run(graph()))
177 auto folded_const = getFoldedPattern();
178 EXPECT_NE(nullptr, folded_const);
180 // Chec type, shape, values of folded const
181 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
182 EXPECT_EQ(3, folded_const->rank());
183 EXPECT_EQ(2, folded_const->dim(0).value());
184 EXPECT_EQ(2, folded_const->dim(1).value());
185 EXPECT_EQ(2, folded_const->dim(2).value());
186 EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
187 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
188 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
189 EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
190 EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
191 EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
192 EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
193 EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
196 TEST_F(U8FoldDequantizeTest, fold_dequant_basic_NEG)
198 createNotFoldablePattern();
200 luci::FoldDequantizePass pass;
201 while (pass.run(graph()))
204 auto folded_const = getFoldedPattern();
205 EXPECT_EQ(nullptr, folded_const);
208 TEST_F(S8FoldDequantizeTest, fold_dequant_basic)
210 luci::FoldDequantizePass pass;
211 while (pass.run(graph()))
214 auto folded_const = getFoldedPattern();
215 EXPECT_NE(nullptr, folded_const);
217 // Chec type, shape, values of folded const
218 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
219 EXPECT_EQ(3, folded_const->rank());
220 EXPECT_EQ(2, folded_const->dim(0).value());
221 EXPECT_EQ(2, folded_const->dim(1).value());
222 EXPECT_EQ(2, folded_const->dim(2).value());
223 EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
224 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
225 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
226 EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
227 EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
228 EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
229 EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
230 EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
233 TEST_F(S8FoldDequantizeTest, fold_dequant_basic_NEG)
235 createNotFoldablePattern();
237 luci::FoldDequantizePass pass;
238 while (pass.run(graph()))
241 auto folded_const = getFoldedPattern();
242 EXPECT_EQ(nullptr, folded_const);
245 TEST_F(S16FoldDequantizeTest, fold_dequant_basic)
247 luci::FoldDequantizePass pass;
248 while (pass.run(graph()))
251 auto folded_const = getFoldedPattern();
252 EXPECT_NE(nullptr, folded_const);
254 // Chec type, shape, values of folded const
255 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
256 EXPECT_EQ(3, folded_const->rank());
257 EXPECT_EQ(2, folded_const->dim(0).value());
258 EXPECT_EQ(2, folded_const->dim(1).value());
259 EXPECT_EQ(2, folded_const->dim(2).value());
260 EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
261 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
262 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
263 EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
264 EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
265 EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
266 EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
267 EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
270 TEST_F(S16FoldDequantizeTest, fold_dequant_basic_NEG)
272 createNotFoldablePattern();
274 luci::FoldDequantizePass pass;
275 while (pass.run(graph()))
278 auto folded_const = getFoldedPattern();
279 EXPECT_EQ(nullptr, folded_const);
282 TEST_F(S32FoldDequantizeTest, fold_dequant_basic)
284 luci::FoldDequantizePass pass;
285 while (pass.run(graph()))
288 auto folded_const = getFoldedPattern();
289 EXPECT_NE(nullptr, folded_const);
291 // Chec type, shape, values of folded const
292 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
293 EXPECT_EQ(3, folded_const->rank());
294 EXPECT_EQ(2, folded_const->dim(0).value());
295 EXPECT_EQ(2, folded_const->dim(1).value());
296 EXPECT_EQ(2, folded_const->dim(2).value());
297 EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
298 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
299 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
300 EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
301 EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
302 EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
303 EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
304 EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
307 TEST_F(S32FoldDequantizeTest, fold_dequant_basic_NEG)
309 createNotFoldablePattern();
311 luci::FoldDequantizePass pass;
312 while (pass.run(graph()))
315 auto folded_const = getFoldedPattern();
316 EXPECT_EQ(nullptr, folded_const);
319 TEST_F(S64FoldDequantizeTest, fold_dequant_basic)
321 luci::FoldDequantizePass pass;
322 while (pass.run(graph()))
325 auto folded_const = getFoldedPattern();
326 EXPECT_NE(nullptr, folded_const);
328 // Chec type, shape, values of folded const
329 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
330 EXPECT_EQ(3, folded_const->rank());
331 EXPECT_EQ(2, folded_const->dim(0).value());
332 EXPECT_EQ(2, folded_const->dim(1).value());
333 EXPECT_EQ(2, folded_const->dim(2).value());
334 EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
335 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
336 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
337 EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
338 EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
339 EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
340 EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
341 EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
344 TEST_F(S64FoldDequantizeTest, fold_dequant_basic_NEG)
346 createNotFoldablePattern();
348 luci::FoldDequantizePass pass;
349 while (pass.run(graph()))
352 auto folded_const = getFoldedPattern();
353 EXPECT_EQ(nullptr, folded_const);
356 TEST_F(U8FoldDequantizeTest, fold_dequant_scalar)
358 createScalarPattern();
360 luci::FoldDequantizePass pass;
361 while (pass.run(graph()))
364 auto folded_const = getFoldedPattern();
365 EXPECT_NE(nullptr, folded_const);
367 // Check type, shape, values of folded const
368 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
369 EXPECT_EQ(0, folded_const->rank());
370 EXPECT_EQ(1.0, folded_const->at<loco::DataType::FLOAT32>(0));
373 TEST_F(F16FoldDequantizeTest, fold_dequant_basic)
375 luci::FoldDequantizePass pass;
376 while (pass.run(graph()))
379 auto folded_const = getFoldedPattern();
380 EXPECT_NE(nullptr, folded_const);
382 // Chec type, shape, values of folded const
383 EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
384 EXPECT_EQ(2, folded_const->rank());
385 EXPECT_EQ(2, folded_const->dim(0).value());
386 EXPECT_EQ(2, folded_const->dim(1).value());
387 EXPECT_EQ(-2.5, folded_const->at<loco::DataType::FLOAT32>(0));
388 EXPECT_EQ(-0.5, folded_const->at<loco::DataType::FLOAT32>(1));
389 EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
390 EXPECT_EQ(1.5, folded_const->at<loco::DataType::FLOAT32>(3));
393 TEST_F(F16FoldDequantizeTest, fold_dequant_basic_NEG)
395 createNotFoldablePattern();
397 luci::FoldDequantizePass pass;
398 while (pass.run(graph()))
401 auto folded_const = getFoldedPattern();
402 EXPECT_EQ(nullptr, folded_const);