Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / RequantizePass.test.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/RequantizePass.h"
18
19 #include "helpers/CreateCircleConst.h"
20
21 #include <luci/test/TestIOGraph.h>
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleQuantParam.h>
24
25 #include <vector>
26
27 #include <gtest/gtest.h>
28
29 using namespace luci;
30 using namespace luci::test;
31
32 namespace
33 {
34
35 /**
36  *  Simple graph for test
37  *
38  *  BEFORE
39  *
40  * [IFM (S8)] [W (S8)] [B (S32)]
41  *       |       |        |
42  *       +-------+--------+
43  *               |
44  *               V
45  *              [FC]
46  *               |
47  *               V
48  *           [OFM(S8)]
49  *
50  *  AFTER
51  *
52  * [IFM (U8)] [W (U8)] [B (S32)]
53  *       |       |        |
54  *       +-------+--------+
55  *               |
56  *               V
57  *              [FC]
58  *               |
59  *               V
60  *           [OFM(U8)]
61  */
62 struct S8FCGraphlet
63 {
64 public:
65   S8FCGraphlet() = default;
66   virtual ~S8FCGraphlet() = default;
67
68   void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape,
69             const ShapeU32 bias_shape)
70   {
71     _fc = g->nodes()->create<CircleFullyConnected>();
72     _fc->input(_x);
73     _x->dtype(loco::DataType::S8);
74     {
75       auto quantparam = std::make_unique<CircleQuantParam>();
76       quantparam->scale.push_back(1.0);
77       quantparam->zerop.push_back(0);
78       quantparam->quantized_dimension = 0;
79       _x->quantparam(std::move(quantparam));
80     }
81
82     _weights = create_const_node<int8_t>(g, loco::DataType::S8, w_shape, 1.0);
83     {
84       auto w_qparam = std::make_unique<CircleQuantParam>();
85       std::vector<float> w_scale(_weights->dim(0).value(), 1.0);
86       std::vector<int64_t> w_zp(_weights->dim(0).value(), 0);
87       w_qparam->scale = w_scale;
88       w_qparam->zerop = w_zp;
89       w_qparam->quantized_dimension = 0;
90       _weights->quantparam(std::move(w_qparam));
91     }
92     _fc->weights(_weights);
93
94     _bias = create_const_node<int32_t>(g, loco::DataType::S32, bias_shape, 1.0);
95     {
96       auto b_qparam = std::make_unique<CircleQuantParam>();
97       const auto bias_size = _bias->size<loco::DataType::S32>();
98       std::vector<float> b_scale(bias_size, 1.0);
99       std::vector<int64_t> b_zp(bias_size, 0);
100       b_qparam->scale = b_scale;
101       b_qparam->zerop = b_zp;
102       b_qparam->quantized_dimension = 0;
103       _bias->quantparam(std::move(b_qparam));
104     }
105
106     _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
107     _fc->dtype(loco::DataType::S8);
108     _fc->shape(out_shape);
109     _fc->bias(_bias);
110     _fc->name("fc");
111     {
112       auto quantparam = std::make_unique<CircleQuantParam>();
113       quantparam->scale.push_back(1.0);
114       quantparam->zerop.push_back(0);
115       quantparam->quantized_dimension = 0;
116       _fc->quantparam(std::move(quantparam));
117     }
118   }
119
120 public:
121   CircleFullyConnected *_fc = nullptr;
122   CircleInput *_x = nullptr;
123   CircleConst *_weights = nullptr;
124   CircleConst *_bias = nullptr;
125 };
126
127 struct S8FCGraph final : public TestIGraphlet, public TestOGraphlet, public S8FCGraphlet
128 {
129   void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape,
130             const ShapeU32 bias_shape)
131   {
132     TestIGraphlet::init(g(), in_shape);
133     TestOGraphlet::init(g(), out_shape);
134     _x = input();
135     S8FCGraphlet::init(g(), out_shape, w_shape, bias_shape);
136     output()->from(_fc);
137   }
138 };
139
140 class RequantizeS8ToU8FCTest : public ::testing::Test
141 {
142 public:
143   S8FCGraph g;
144 };
145
146 } // namespace
147
148 TEST(RequantizePassTest, name)
149 {
150   luci::RequantizePass pass(loco::DataType::FLOAT32, loco::DataType::U8);
151   auto const name = pass.name();
152   ASSERT_NE(nullptr, name);
153 }
154
155 TEST_F(RequantizeS8ToU8FCTest, FC)
156 {
157   g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */,
158          {1, 256} /* bias shape*/);
159
160   luci::RequantizePass rq(loco::DataType::S8, loco::DataType::U8);
161   rq.run(g.g());
162
163   EXPECT_EQ(loco::DataType::U8, g._x->dtype());
164   EXPECT_EQ(loco::DataType::U8, g._fc->dtype());
165   EXPECT_EQ(loco::DataType::U8, g._weights->dtype());
166   EXPECT_EQ(loco::DataType::S32, g._bias->dtype());
167 }
168
169 TEST_F(RequantizeS8ToU8FCTest, FC_wrong_dtype_NEG)
170 {
171   g.init({1, 18, 80} /* ifm shape */, {256, 80} /* weights shape*/, {18, 256} /* ofm shape */,
172          {1, 256} /* bias shape*/);
173
174   // Wrong dtype
175   luci::RequantizePass rq(loco::DataType::U8, loco::DataType::S8);
176   rq.run(g.g());
177
178   EXPECT_EQ(loco::DataType::S8, g._x->dtype());
179   EXPECT_EQ(loco::DataType::S8, g._fc->dtype());
180   EXPECT_EQ(loco::DataType::S8, g._weights->dtype());
181   EXPECT_EQ(loco::DataType::S32, g._bias->dtype());
182 }