2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "QuantizeWeightsOnly.h"
18 #include "QuantizationUtils.h"
20 #include <luci/Service/Nodes/CircleConst.h>
// Callback type used by iterate_per_channel. It is invoked once per visited
// element with (index array, tensor shape, channel dimension index).
33 using IterFunc = std::function<void(uint32_t *, loco::TensorShape &, int32_t)>;
// Visits every element of `node` through four nested loops (one per dimension
// of `dimension`, indices[3] varying fastest) and calls
// `func(indices, dimension, channel_dim_index)` for each index combination.
//
// node              : constant whose elements are traversed
// channel_dim_index : passed by reference to get_channel_dim_index and then
//                     forwarded to `func`; presumably filled in by that helper
//                     -- TODO confirm against QuantizationUtils
// func              : per-element callback (see IterFunc)
//
// NOTE(review): this listing appears to omit some short lines (braces and
// possibly statements such as the shape set-up and the body of the failure
// branch). Verify against the complete file before relying on the details.
35 void iterate_per_channel(CircleConst *node, int32_t &channel_dim_index, IterFunc func)
37 loco::TensorShape dimension;
// Current position in the tensor; reused as the loop counters below.
39 uint32_t indices[4] = {
// Failure path when the channel dimension cannot be determined
// (its body is not visible in this listing).
43 if (!get_channel_dim_index(node, dimension, channel_dim_index))
// The loops access dim(0)..dim(3), i.e. `dimension` is treated as rank 4 here.
49 for (indices[0] = 0; indices[0] < dimension.dim(0).value(); indices[0]++)
51 for (indices[1] = 0; indices[1] < dimension.dim(1).value(); indices[1]++)
53 for (indices[2] = 0; indices[2] < dimension.dim(2).value(); indices[2]++)
55 for (indices[3] = 0; indices[3] < dimension.dim(3).value(); indices[3]++)
57 func(indices, dimension, channel_dim_index);
64 // TODO Reduce duplicate code with QuantizeDequantizeWeights
65 template <loco::DataType out_type>
66 void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
67 std::vector<float> &scaling_factor, std::vector<float> &nudged_min,
68 std::vector<float> &nudged_max, int32_t &channel_dim_index)
70 assert(node->dtype() == loco::DataType::FLOAT32);
71 assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
72 const int32_t kMaxScale = (out_type == loco::DataType::S8) ? std::numeric_limits<int8_t>::max()
73 : std::numeric_limits<int16_t>::max();
74 const int32_t kMinScale = -kMaxScale;
76 uint32_t size = node->size<loco::DataType::FLOAT32>();
77 std::vector<int32_t> quantized_values(size);
79 for (size_t i = 0; i < min.size(); ++i)
81 compute_sym_scale(min[i], max[i], scaling_factor[i], nudged_min[i], nudged_max[i], out_type);
84 auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
85 int channel_idx = indices[channel_dim_index];
86 const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
87 auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
88 data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
89 data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
90 quantized_values[cal_offset(dimension, indices)] =
91 static_cast<int32_t>(std::round(data * scaling_factor_inv));
94 iterate_per_channel(node, channel_dim_index, quantize);
96 node->dtype(out_type); // change the type of tensor
97 node->size<out_type>(size); // resize tensor
98 for (uint32_t i = 0; i < size; ++i)
100 node->at<out_type>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
// Computes the minimum and maximum FLOAT32 element value for each index along
// the channel dimension of `node`, storing them in `min` / `max` indexed by
// channel.
//
// node              : constant whose elements are scanned (read as FLOAT32)
// min, max          : per-channel results, written via min[channel] / max[channel]
// channel_dim_index : passed by reference to get_channel_dim_index and to
//                     iterate_per_channel
// Throws std::runtime_error when get_channel_dim_index reports failure.
//
// NOTE(review): this listing appears to omit some short lines (braces and
// possibly statements). Verify against the complete file before editing.
104 void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
105 int32_t &channel_dim_index)
107 loco::TensorShape dimension;
110 if (!get_channel_dim_index(node, dimension, channel_dim_index))
112 throw std::runtime_error("Failed to find channel index in " + node->name());
// Number of channels = extent of the channel dimension.
114 auto size = dimension.dim(channel_dim_index).value();
// Marks the channels that have already received their first value.
116 std::vector<bool> has_min_max_value(size, false);
// NOTE(review): min/max are indexed by channel below, so they must hold at
// least `size` elements; where they are sized is not visible here -- confirm.
120 auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
121 int channel_idx = indices[channel_dim_index];
122 auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
// Channel already seen: widen the running [min, max] range if needed.
123 if (has_min_max_value[channel_idx])
125 min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx];
126 max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx];
// Presumably the first-visit branch (the `else` is not visible in this
// listing): the element seeds both min and max for its channel.
130 min[channel_idx] = data;
131 max[channel_idx] = data;
132 has_min_max_value[channel_idx] = true;
136 iterate_per_channel(node, channel_dim_index, cal_minmax);
// Quantizes `weights` in place with symmetric per-channel quantization and
// attaches a newly created CircleQuantParam holding the per-channel scales,
// zero points and the quantized dimension.
//
// Throws std::runtime_error when output_type is neither S8 nor S16, and with
// the "does not support layer-wise" message (the condition guarding that throw
// is not visible in this listing -- presumably the non-ChannelWise case; confirm).
//
// NOTE(review): this listing appears to omit some short lines (braces, `else`
// keywords, possibly a `return`). Verify against the complete file before editing.
144 void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)
146 // Find min/max per channel-wise
147 if (granularity == QuantizationGranularity::ChannelWise)
// The work below is guarded by the weights having no quantparam yet.
149 auto quantparam = weights->quantparam();
150 if (quantparam == nullptr)
152 // Find min/max on the fly
153 // NOTE This is for the case when QuantizeDequantizeWeights is skipped
154 // TODO Reduce duplicate codes
155 std::vector<float> min;
156 std::vector<float> max;
157 int32_t channel_dim_index = 0;
159 cal_minmax_per_channel(weights, min, max, channel_dim_index);
// Per-channel outputs, one entry per channel found above.
161 std::vector<float> nudged_min(min.size());
162 std::vector<float> nudged_max(min.size());
163 std::vector<float> scaling_factor(min.size());
// Value-initialized, so every zero point is 0.
164 std::vector<int64_t> zp(min.size());
// Dispatch on the runtime output_type to the matching template instantiation.
166 if (output_type == loco::DataType::S8)
168 sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
169 nudged_max, channel_dim_index);
171 else if (output_type == loco::DataType::S16)
173 sym_wquant_per_channel<loco::DataType::S16>(weights, min, max, scaling_factor, nudged_min,
174 nudged_max, channel_dim_index);
178 throw std::runtime_error("Weights-only quantization supports s8 and s16");
// Record the quantization parameters on the weights node.
181 auto quantparam = std::make_unique<CircleQuantParam>();
182 quantparam->scale = scaling_factor;
183 quantparam->zerop = zp;
184 quantparam->quantized_dimension = channel_dim_index;
185 weights->quantparam(std::move(quantparam));
191 throw std::runtime_error("Weights-only quantization does not support layer-wise");
// Conv2D visitor: when the filter constant is not yet quantized (per
// is_quantized), clones it, points the node's filter at the clone and quantizes
// the clone. The original constant itself is not modified here -- presumably so
// that other users of it are unaffected (confirm).
// The filter input is required to be a CircleConst (must_cast).
// NOTE(review): the declaration of the logger handle `l` is not visible in
// this listing.
194 void QuantizeWeightsOnly::visit(luci::CircleConv2D *node)
197 INFO(l) << "QuantizeWeightsOnly visits node: " << node->name() << std::endl;
199 auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
200 if (!is_quantized(weights))
202 auto new_weights = luci::clone(weights);
203 node->filter(new_weights);
204 quantize_weights(new_weights);
// DepthwiseConv2D visitor: when the filter constant is not yet quantized (per
// is_quantized), clones it, points the node's filter at the clone and quantizes
// the clone. The original constant itself is not modified here -- presumably so
// that other users of it are unaffected (confirm).
// The filter input is required to be a CircleConst (must_cast).
// NOTE(review): the declaration of the logger handle `l` is not visible in
// this listing.
208 void QuantizeWeightsOnly::visit(luci::CircleDepthwiseConv2D *node)
211 INFO(l) << "QuantizeWeightsOnly visits node: " << node->name() << std::endl;
213 auto weights = loco::must_cast<luci::CircleConst *>(node->filter());
214 if (!is_quantized(weights))
216 auto new_weights = luci::clone(weights);
217 node->filter(new_weights);
218 quantize_weights(new_weights);
222 void QuantizeWeightsOnly::visit(luci::CircleNode *) {}