2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
16 #include <gtest/gtest.h>
18 #include "Quantizer.h"
19 #include "TestHelper.h"
21 #include <luci/IR/CircleNodes.h>
// Test graph that adds a FLOAT32 CircleAdd node (`_add`) and a CircleConst
// operand (`_beta`) on top of the SimpleGraph test scaffold (presumably
// declared in TestHelper.h).
// NOTE(review): this copy of the class is missing several lines (braces,
// access specifiers, the wiring of `_add`'s operands, the return statement of
// insertGraphBody and the declarations of `_a_min`/`_a_max`) — restore them
// from version control before relying on this fixture.
28 class AddGraph final : public SimpleGraph
// Hook for the graph input node: casts it to a luci::CircleNode.
// NOTE(review): presumably the (not visible) next statement seeds the input's
// min/max via initMinMax(ci_input) — confirm.
31 void initInput(loco::Node *input) override
33 auto ci_input = loco::must_cast<luci::CircleNode *>(input);
// Builds a CircleQuantParam with a single-element min (`_a_min`) and a
// single-element max (`_a_max`) and moves it into `node`, so the quantizer
// under test has a recorded range to derive scale/zero-point from.
37 void initMinMax(luci::CircleNode *node)
39 auto qparam = std::make_unique<luci::CircleQuantParam>();
40 qparam->min.assign(1, _a_min);
41 qparam->max.assign(1, _a_max);
42 node->quantparam(std::move(qparam));
// Creates `_add` (CircleAdd) and `_beta` (CircleConst) in `_g`, both FLOAT32
// and shaped {1, _channel_size, _width, _height}; `_add` gets no fused
// activation.
45 loco::Node *insertGraphBody(loco::Node *input) override
47 _add = _g->nodes()->create<luci::CircleAdd>();
48 _beta = _g->nodes()->create<luci::CircleConst>();
50 _add->dtype(loco::DataType::FLOAT32);
51 _beta->dtype(loco::DataType::FLOAT32);
// NOTE(review): this local duplicates the `_channel_size` member that the
// shapes below use; keep a single source of truth.
53 uint32_t channel_size = 16;
54 _add->shape({1, _channel_size, _width, _height});
55 _beta->shape({1, _channel_size, _width, _height});
// NOTE(review): `_beta` is shaped {1, _channel_size, _width, _height} but its
// buffer is sized to only `channel_size` elements; it should presumably hold
// _channel_size * _width * _height values so buffer and shape agree.
57 _beta->size<loco::DataType::FLOAT32>(channel_size);
60 _add->fusedActivationFunction(luci::FusedActFunc::NONE);
// Handles to the nodes created in insertGraphBody, kept so tests can inspect
// them after quantization (access specifier not visible in this copy).
72 luci::CircleAdd *_add = nullptr;
73 luci::CircleConst *_beta = nullptr;
// Quantizes a module built from the test graph to uint8 and checks the Add
// node's resulting quantization parameters: exactly one scale equal to
// (max - min) / 255 and exactly one zero point equal to 128.
// NOTE(review): the declarations of `g` and `add` are not visible in this copy;
// `add` is presumably the graph's `_add` node — confirm and restore them.
78 TEST(CircleMPQSolverQuantizerTest, verifyResultsTest)
80 auto m = luci::make_module();
// Expected uint8 scale is derived from the min/max range recorded on the graph.
84 float range = g._a_max - g._a_min;
85 g.transfer_to(m.get());
87 std::string def_quant = "uint8";
88 mpqsolver::core::Quantizer quantizer(def_quant, def_quant);
89 mpqsolver::core::LayerParams params;
// NOTE(review): `res` is not checked in this block; add EXPECT_TRUE(res) so a
// failed quantization is reported directly rather than via the checks below.
90 auto res = quantizer.quantize(m.get(), def_quant, params);
92 auto quant_param = add->quantparam();
// NOTE(review): EXPECT_TRUE is non-fatal, so if quant_param is null the test
// keeps running and dereferences it below; ASSERT_NE(quant_param, nullptr)
// would stop the test first. EXPECT_EQ would also give clearer failure output
// than EXPECT_TRUE(a == b) for the size/zero-point checks.
93 EXPECT_TRUE(quant_param != nullptr);
94 EXPECT_TRUE(quant_param->scale.size() == 1);
95 EXPECT_FLOAT_EQ(quant_param->scale[0], range / 255.f);
96 EXPECT_TRUE(quant_param->zerop.size() == 1);
// NOTE(review): 128 is the uint8 zero point for a range symmetric around 0
// (round(-min / scale)), so this presumably assumes _a_min == -_a_max — confirm.
97 EXPECT_TRUE(quant_param->zerop[0] == 128);
100 TEST(CircleMPQSolverQuantizerTest, verifyResultsTest_NEG)
102 std::string def_quant = "uint8";
103 mpqsolver::core::Quantizer quantizer(def_quant, def_quant);
104 mpqsolver::core::LayerParams params;
105 auto res = quantizer.quantize(nullptr, def_quant, params);