2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/RequantizePass.h"
19 #include "helpers/CreateCircleConst.h"
21 #include <luci/test/TestIOGraph.h>
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleQuantParam.h>
27 #include <gtest/gtest.h>
30 using namespace luci::test;
36 * Simple FullyConnected graph for testing RequantizePass (S8 -> U8)
40 * [IFM (S8)] [W (S8)] [B (S32)]
52 * [IFM (U8)] [W (U8)] [B (S32)]
// Default construction: every node pointer starts as nullptr via its in-class initializer.
65 S8FCGraphlet() = default;
// Virtual destructor so the graphlet can safely serve as a base (S8FCGraph derives from it).
66 virtual ~S8FCGraphlet() = default;
// Builds an S8-quantized FullyConnected node in `g`:
//   - marks the input `_x` as S8 with a single scale 1.0 / zero-point 0 qparam,
//   - creates S8 const weights `_weights` with one scale/zero-point per index of
//     weights dimension 0 (quantized_dimension = 0),
//   - creates S32 const bias `_bias` with one scale/zero-point per bias element,
//   - sets the FC node to S8, no fused activation, shape `out_shape`, and a
//     single scale 1.0 / zero-point 0 output qparam.
//
// @param g           graph in which the FC node and the constants are created
// @param out_shape   shape assigned to the FC node output
// @param w_shape     shape of the weights constant
// @param bias_shape  shape of the bias constant
//
// NOTE(review): `_x` is dereferenced below, but this function never assigns it;
// it presumably must be set (e.g. to the graph input) before init() is called.
// TODO confirm.
// NOTE(review): no `_fc->input(...)` or `_fc->bias(...)` call appears in this
// listing (only `_fc->weights(...)`). Confirm the FC node is fully wired.
68 void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape,
69 const ShapeU32 bias_shape)
71 _fc = g->nodes()->create<CircleFullyConnected>();
// Input activation: S8, per-tensor qparam (scale 1.0, zero-point 0).
73 _x->dtype(loco::DataType::S8);
75 auto quantparam = std::make_unique<CircleQuantParam>();
76 quantparam->scale.push_back(1.0);
77 quantparam->zerop.push_back(0);
78 quantparam->quantized_dimension = 0;
79 _x->quantparam(std::move(quantparam));
// Weights constant; create_const_node is a helper defined elsewhere
// (presumably it fills every element with the given value, 1.0 here).
82 _weights = create_const_node<int8_t>(g, loco::DataType::S8, w_shape, 1.0);
// Per-axis weight qparams: one scale/zero-point for each index of dimension 0.
84 auto w_qparam = std::make_unique<CircleQuantParam>();
85 std::vector<float> w_scale(_weights->dim(0).value(), 1.0);
86 std::vector<int64_t> w_zp(_weights->dim(0).value(), 0);
87 w_qparam->scale = w_scale;
88 w_qparam->zerop = w_zp;
89 w_qparam->quantized_dimension = 0;
90 _weights->quantparam(std::move(w_qparam));
92 _fc->weights(_weights);
// Bias stays S32; its qparams are sized by the total bias element count.
94 _bias = create_const_node<int32_t>(g, loco::DataType::S32, bias_shape, 1.0);
96 auto b_qparam = std::make_unique<CircleQuantParam>();
97 const auto bias_size = _bias->size<loco::DataType::S32>();
98 std::vector<float> b_scale(bias_size, 1.0);
99 std::vector<int64_t> b_zp(bias_size, 0);
100 b_qparam->scale = b_scale;
101 b_qparam->zerop = b_zp;
102 b_qparam->quantized_dimension = 0;
103 _bias->quantparam(std::move(b_qparam));
// FC output: S8, no fused activation, per-tensor qparam (scale 1.0, zero-point 0).
106 _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
107 _fc->dtype(loco::DataType::S8);
108 _fc->shape(out_shape);
112 auto quantparam = std::make_unique<CircleQuantParam>();
113 quantparam->scale.push_back(1.0);
114 quantparam->zerop.push_back(0);
115 quantparam->quantized_dimension = 0;
116 _fc->quantparam(std::move(quantparam));
// Nodes populated by init(). They are left accessible so the tests can inspect
// their dtypes after requantization.
121 CircleFullyConnected *_fc = nullptr;
// Not created by init(); it must point to a valid input node before init() runs.
122 CircleInput *_x = nullptr;
123 CircleConst *_weights = nullptr;
124 CircleConst *_bias = nullptr;
// Test graph combining an input graphlet, an output graphlet and the S8
// FullyConnected graphlet.
127 struct S8FCGraph final : public TestIGraphlet, public TestOGraphlet, public S8FCGraphlet
// Initializes input, output and FC parts.
// Note: the argument order (in, w, out, bias) differs from
// S8FCGraphlet::init (out, w, bias).
// NOTE(review): `g()` is not declared in this struct; it is presumably inherited.
// Also, no assignment of `_x` and no connection of the graph output to `_fc`
// appears here. Confirm both exist, otherwise S8FCGraphlet::init dereferences a
// null `_x`.
129 void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape,
130 const ShapeU32 bias_shape)
132 TestIGraphlet::init(g(), in_shape);
133 TestOGraphlet::init(g(), out_shape);
135 S8FCGraphlet::init(g(), out_shape, w_shape, bias_shape);
// GoogleTest fixture for requantizing an S8 FullyConnected graph.
// NOTE(review): the fixture body is not shown here. The tests use a member `g`
// with a 4-shape init() and `_x/_fc/_weights/_bias` fields, so it is presumably
// an S8FCGraph. Confirm.
140 class RequantizeS8ToU8FCTest : public ::testing::Test
// Sanity check: RequantizePass::name() returns a non-null string.
148 TEST(RequantizePassTest, name)
150 luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8);
151 auto const name = pass.name();
152 ASSERT_NE(nullptr, name);
// Positive case: after S8 -> U8 requantization, the input, FC output and weights
// are expected to be U8, while the bias stays S32.
155 TEST_F(RequantizeS8ToU8FCTest, FC)
157 g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */,
158 {1, 256} /* bias shape*/);
160 luci::RequantizePass rq(loco::DataType::S8, loco::DataType::U8);
// NOTE(review): no `rq.run(...)` call appears between constructing `rq` and the
// expectations below. Confirm the pass is actually run. Without it, the graph
// still holds the S8 dtypes set by S8FCGraphlet::init and these U8 expectations
// fail.
163 EXPECT_EQ(loco::DataType::U8, g._x->dtype());
164 EXPECT_EQ(loco::DataType::U8, g._fc->dtype());
165 EXPECT_EQ(loco::DataType::U8, g._weights->dtype());
166 EXPECT_EQ(loco::DataType::S32, g._bias->dtype());
// Negative case: the pass is constructed as (U8, S8), presumably source/target
// dtypes, i.e. a U8 -> S8 pass applied to an S8 graph. Every node is expected to
// keep its original dtype.
169 TEST_F(RequantizeS8ToU8FCTest, FC_wrong_dtype_NEG)
171 g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */,
172 {1, 256} /* bias shape*/);
175 luci::RequantizePass rq(loco::DataType::U8, loco::DataType::S8);
// NOTE(review): no `rq.run(...)` call is visible before these expectations. If
// the pass is never run, this negative test passes trivially and checks nothing.
// Confirm the pass is run.
178 EXPECT_EQ(loco::DataType::S8, g._x->dtype());
179 EXPECT_EQ(loco::DataType::S8, g._fc->dtype());
180 EXPECT_EQ(loco::DataType::S8, g._weights->dtype());
181 EXPECT_EQ(loco::DataType::S32, g._bias->dtype());