8575156d2ca55d2e45df17304c8fe8e592e210d0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / bisection / Quantizer.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <gtest/gtest.h>
17
18 #include "Quantizer.h"
19 #include "TestHelper.h"
20
21 #include <luci/IR/CircleNodes.h>
22
23 #include <cmath>
24
25 namespace
26 {
27
28 class AddGraph final : public SimpleGraph
29 {
30 protected:
31   void initInput(loco::Node *input) override
32   {
33     auto ci_input = loco::must_cast<luci::CircleNode *>(input);
34     initMinMax(ci_input);
35   }
36
37   void initMinMax(luci::CircleNode *node)
38   {
39     auto qparam = std::make_unique<luci::CircleQuantParam>();
40     qparam->min.assign(1, _a_min);
41     qparam->max.assign(1, _a_max);
42     node->quantparam(std::move(qparam));
43   }
44
45   loco::Node *insertGraphBody(loco::Node *input) override
46   {
47     _add = _g->nodes()->create<luci::CircleAdd>();
48     _beta = _g->nodes()->create<luci::CircleConst>();
49
50     _add->dtype(loco::DataType::FLOAT32);
51     _beta->dtype(loco::DataType::FLOAT32);
52
53     uint32_t channel_size = 16;
54     _add->shape({1, _channel_size, _width, _height});
55     _beta->shape({1, _channel_size, _width, _height});
56
57     _beta->size<loco::DataType::FLOAT32>(channel_size);
58     _add->x(input);
59     _add->y(_beta);
60     _add->fusedActivationFunction(luci::FusedActFunc::NONE);
61
62     _add->name("add");
63     _beta->name("beta");
64     initMinMax(_add);
65
66     return _add;
67   }
68
69 public:
70   float _a_min = -1.f;
71   float _a_max = 1.f;
72   luci::CircleAdd *_add = nullptr;
73   luci::CircleConst *_beta = nullptr;
74 };
75
76 } // namespace
77
78 TEST(CircleMPQSolverQuantizerTest, verifyResultsTest)
79 {
80   auto m = luci::make_module();
81   AddGraph g;
82   g.init();
83   auto add = g._add;
84   float range = g._a_max - g._a_min;
85   g.transfer_to(m.get());
86
87   std::string def_quant = "uint8";
88   mpqsolver::bisection::Quantizer quantizer(def_quant, def_quant);
89   mpqsolver::bisection::LayerParams params;
90   auto res = quantizer.quantize(m.get(), def_quant, params);
91   EXPECT_TRUE(res);
92   auto quant_param = add->quantparam();
93   EXPECT_TRUE(quant_param != nullptr);
94   EXPECT_TRUE(quant_param->scale.size() == 1);
95   EXPECT_FLOAT_EQ(quant_param->scale[0], range / 255.f);
96   EXPECT_TRUE(quant_param->zerop.size() == 1);
97   EXPECT_TRUE(quant_param->zerop[0] == 128);
98 }
99
100 TEST(CircleMPQSolverQuantizerTest, verifyResultsTest_NEG)
101 {
102   std::string def_quant = "uint8";
103   mpqsolver::bisection::Quantizer quantizer(def_quant, def_quant);
104   mpqsolver::bisection::LayerParams params;
105   auto res = quantizer.quantize(nullptr, def_quant, params);
106   EXPECT_TRUE(!res);
107 }