2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "QuantizeBias.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <luci/test/TestIOGraph.h>
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleQuantParam.h>
25 #include <gtest/gtest.h>
32 using namespace luci::test;
35 * Simple graph for test
39 * [IFM] [WEIGHTS] [BIAS(FP32)]
47 * [IFM] [WEIGHTS] [BIAS(Quantized)]
// Defaulted constructor; the destructor is declared virtual.
// The node pointers start out null and _fc is assigned in init().
56 Q8FCGraphlet() = default;
57 virtual ~Q8FCGraphlet() = default;
// Builds the FullyConnected part of the test graph inside `g`: creates the
// CircleFullyConnected node, attaches U8 constant weights and a FLOAT32
// constant bias, and sets quantization parameters on _x, the weights and the
// FullyConnected node itself.
//
//   g          - graph in which the nodes are created
//   out_shape  - shape assigned to the FullyConnected node
//   w_shape    - shape used for the constant weights
//   bias_shape - shape used for the constant bias
//   bv         - value forwarded to create_const_node for the bias
//                (presumably the fill value - confirm against the helper)
59 void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape,
60 const ShapeU32 bias_shape, const float bv)
62 _fc = g->nodes()->create<luci::CircleFullyConnected>();
// NOTE(review): the statement that assigns _x is not visible in this excerpt.
64 _x->dtype(loco::DataType::U8);
// Quantization parameters for _x: a single scale 1.0 and zero-point 0.
66 auto quantparam = std::make_unique<CircleQuantParam>();
67 quantparam->scale.push_back(1.0);
68 quantparam->zerop.push_back(0);
69 quantparam->quantized_dimension = 0;
70 _x->quantparam(std::move(quantparam));
// U8 weights whose scale/zero-point vectors have one entry (1.0 / 0) per
// element along dimension 0, with quantized_dimension set to 0.
73 auto weights = create_const_node<uint8_t>(g, loco::DataType::U8, w_shape, 1.0);
74 auto w_qparam = std::make_unique<CircleQuantParam>();
75 std::vector<float> w_scale(weights->dim(0).value(), 1.0);
76 std::vector<int64_t> w_zp(weights->dim(0).value(), 0);
77 w_qparam->scale = w_scale;
78 w_qparam->zerop = w_zp;
79 w_qparam->quantized_dimension = 0;
80 weights->quantparam(std::move(w_qparam));
81 _fc->weights(weights);
82 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
83 _fc->dtype(loco::DataType::U8);
84 _fc->shape(out_shape);
// NOTE(review): `l` is not read by any line visible in this excerpt; it looks
// like an unused local - confirm and remove if so.
85 auto l = _fc->dim(_fc->rank() - 1).value();
// The bias is created as FLOAT32.
86 _fc->bias(create_const_node(g, loco::DataType::FLOAT32, bias_shape, bv));
// Quantization parameters for the FullyConnected node: scale 1.0, zero-point 0.
89 auto quantparam = std::make_unique<CircleQuantParam>();
90 quantparam->scale.push_back(1.0);
91 quantparam->zerop.push_back(0);
92 quantparam->quantized_dimension = 0;
93 _fc->quantparam(std::move(quantparam));
// Returns the FullyConnected node created by init(); nullptr before init() runs.
98 luci::CircleFullyConnected *fc() { return _fc; }
// Node handles held by the graphlet; both are initialized to nullptr.
101 luci::CircleFullyConnected *_fc = nullptr;
102 luci::CircleInput *_x = nullptr;
// Test graph that combines TestIGraphlet, TestOGraphlet and Q8FCGraphlet.
105 struct Q8FCGraph final : public TestIGraphlet, public TestOGraphlet, public Q8FCGraphlet
// Initializes the input graphlet with in_shape, the output graphlet with
// out_shape, and then the FullyConnected graphlet, all on g().
// Note the parameter order here (w_shape before out_shape) differs from
// Q8FCGraphlet::init, which takes out_shape before w_shape.
107 void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape,
108 const ShapeU32 bias_shape, const float bv)
110 TestIGraphlet::init(g(), in_shape);
111 TestOGraphlet::init(g(), out_shape);
113 Q8FCGraphlet::init(g(), out_shape, w_shape, bias_shape, bv);
// GoogleTest fixture for the QuantizeBias tests on a FullyConnected node.
118 class CQ8QuantizeBiasFCTest : public ::testing::Test
// QuantizeBias instance constructed with FLOAT32, U8 and channel-wise
// granularity; the tests hand it to the node's accept().
122 luci::QuantizeBias qb{loco::DataType::FLOAT32, loco::DataType::U8,
123 luci::QuantizationGranularity::ChannelWise};
// Checks the quantization parameters on the FullyConnected bias for a
// {1, 256} bias shape: 256 scales, 256 zero-points and quantized dimension 1.
128 TEST_F(CQ8QuantizeBiasFCTest, fully_connected)
130 g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 256}, 1);
133 auto bias = loco::must_cast<CircleConst *>(g.fc()->bias());
134 auto qparam = bias->quantparam();
// Fatal assertion: the checks below dereference qparam, so the test must stop
// here instead of crashing when no quantization parameters are present.
136 ASSERT_NE(nullptr, qparam);
137 EXPECT_EQ(256, qparam->scale.size());
138 EXPECT_EQ(256, qparam->zerop.size());
139 EXPECT_EQ(1, qparam->quantized_dimension);
// Negative case: with a rank-3 bias shape {1, 2, 128}, passing QuantizeBias to
// the FullyConnected node's accept() is expected to throw.
142 TEST_F(CQ8QuantizeBiasFCTest, wrong_bias_shape_NEG)
144 g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 2, 128}, 1);
145 EXPECT_ANY_THROW(g.fc()->accept(&qb)); // Wrong bias shape