Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / VerifyQuantizedBiasScale.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "VerifyQuantizedBiasScale.h"
18
19 #include <cmath>
20
21 // This macro is undef at the end of the file
22 #define RETURN_FALSE_UNLESS(ARG) \
23   if (not(ARG))                  \
24   {                              \
25     return false;                \
26   }
27
28 namespace
29 {
30
31 bool same(float a, float b)
32 {
33   constexpr float epsilon = 1e-10;
34   return std::abs(a - b) < epsilon;
35 }
36
37 // Check bias scale = input scale * weight scale
38 // This function checks both LWQ and CWQ
39 bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias)
40 {
41   auto input_node = loco::must_cast<const luci::CircleNode *>(input);
42   auto input_qparam = input_node->quantparam();
43   RETURN_FALSE_UNLESS(input_qparam != nullptr);
44
45   auto weights_node = loco::must_cast<const luci::CircleNode *>(weights);
46   auto weights_qparam = weights_node->quantparam();
47   RETURN_FALSE_UNLESS(weights_qparam != nullptr);
48
49   auto bias_node = loco::must_cast<const luci::CircleNode *>(bias);
50   auto bias_qparam = bias_node->quantparam();
51   RETURN_FALSE_UNLESS(bias_qparam != nullptr);
52
53   RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1);
54   RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size());
55
56   auto input_scale = input_qparam->scale[0];
57   for (uint32_t i = 0; i < weights_qparam->scale.size(); i++)
58   {
59     auto weights_scale = weights_qparam->scale[i];
60     auto bias_scale = bias_qparam->scale[i];
61     RETURN_FALSE_UNLESS(same(bias_scale, input_scale * weights_scale));
62   }
63   return true;
64 }
65
66 } // namespace
67
68 namespace luci
69 {
70
71 bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node)
72 {
73   RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
74   return true;
75 }
76
77 bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node)
78 {
79   RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
80   return true;
81 }
82
83 bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node)
84 {
85   luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
86   if (bias != nullptr)
87   {
88     RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias()));
89   }
90   return true;
91 }
92
93 bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node)
94 {
95   luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
96   if (bias != nullptr)
97   {
98     RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias()));
99   }
100   return true;
101 }
102
103 } // namespace luci
104
105 #undef RETURN_FALSE_UNLESS