/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "VerifyQuantizedBiasScale.h"

#include <cmath>
#include <cstdint>
// Return false from the enclosing function when ARG evaluates to false.
// The do/while(0) wrapper turns the expansion into a single statement, so
// `RETURN_FALSE_UNLESS(x);` behaves correctly inside an unbraced if/else.
// This macro is #undef'd at the end of the file.
#define RETURN_FALSE_UNLESS(ARG) \
  do                             \
  {                              \
    if (!(ARG))                  \
    {                            \
      return false;              \
    }                            \
  } while (0)
// Return true if two quantization scales are equal within a relative tolerance.
//
// The tolerance scales with the larger magnitude of the two operands. A fixed
// absolute epsilon does not work for scales: values far below the epsilon
// would always compare equal even when they differ by orders of magnitude,
// while values around 1e-3 and above would require bit-exact equality.
bool same(float a, float b)
{
  // Relative tolerance: roughly 8x the float machine epsilon (~1.19e-7).
  constexpr float relative_epsilon = 1e-6f;

  // Exact matches (including +0 vs -0) need no tolerance at all.
  if (a == b)
    return true;

  const float abs_a = std::abs(a);
  const float abs_b = std::abs(b);
  const float largest = abs_a > abs_b ? abs_a : abs_b;

  // NaN operands make this comparison false, so NaN never matches anything.
  return std::abs(a - b) <= relative_epsilon * largest;
}
37 // Check bias scale = input scale * weight scale
38 // This function checks both LWQ and CWQ
39 bool check_bias_scale(const loco::Node *input, const loco::Node *weights, const loco::Node *bias)
41 auto input_node = loco::must_cast<const luci::CircleNode *>(input);
42 auto input_qparam = input_node->quantparam();
43 RETURN_FALSE_UNLESS(input_qparam != nullptr);
45 auto weights_node = loco::must_cast<const luci::CircleNode *>(weights);
46 auto weights_qparam = weights_node->quantparam();
47 RETURN_FALSE_UNLESS(weights_qparam != nullptr);
49 auto bias_node = loco::must_cast<const luci::CircleNode *>(bias);
50 auto bias_qparam = bias_node->quantparam();
51 RETURN_FALSE_UNLESS(bias_qparam != nullptr);
53 RETURN_FALSE_UNLESS(input_qparam->scale.size() == 1);
54 RETURN_FALSE_UNLESS(weights_qparam->scale.size() == bias_qparam->scale.size());
56 auto input_scale = input_qparam->scale[0];
57 for (uint32_t i = 0; i < weights_qparam->scale.size(); i++)
59 auto weights_scale = weights_qparam->scale[i];
60 auto bias_scale = bias_qparam->scale[i];
61 RETURN_FALSE_UNLESS(same(bias_scale, input_scale * weights_scale));
71 bool VerifyQuantizedBiasScale::visit(const luci::CircleConv2D *node)
73 RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
77 bool VerifyQuantizedBiasScale::visit(const luci::CircleDepthwiseConv2D *node)
79 RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->filter(), node->bias()));
83 bool VerifyQuantizedBiasScale::visit(const luci::CircleFullyConnected *node)
85 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
88 RETURN_FALSE_UNLESS(check_bias_scale(node->input(), node->weights(), node->bias()));
93 bool VerifyQuantizedBiasScale::visit(const luci::CircleTransposeConv *node)
95 luci::CircleConst *bias = dynamic_cast<luci::CircleConst *>(node->bias());
98 RETURN_FALSE_UNLESS(check_bias_scale(node->outBackprop(), node->filter(), node->bias()));
105 #undef RETURN_FALSE_UNLESS