2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "MinMaxComputer.h"
18 #include "RecordFunction.h"
20 #include <luci/IR/CircleQuantParam.h>
22 namespace record_minmax
25 void PercentileComputer::update_qparam(
26 const std::unordered_map<const luci::CircleNode *, MinMaxVectors> *minmax_map)
28 if (minmax_map == nullptr)
29 throw std::invalid_argument("minmax_map is nullptr");
31 for (auto iter = minmax_map->begin(); iter != minmax_map->end(); ++iter)
33 auto node = iter->first;
34 auto minmax = iter->second;
36 auto min = getNthPercentile(minmax.min_vector, _min_percentile);
37 auto max = getNthPercentile(minmax.max_vector, _max_percentile);
39 auto quantparam = std::make_unique<luci::CircleQuantParam>();
40 quantparam->min.push_back(min);
41 quantparam->max.push_back(max);
43 assert(node->quantparam() == nullptr);
45 auto mutable_node = const_cast<luci::CircleNode *>(node);
46 mutable_node->quantparam(std::move(quantparam));
50 void MovingAvgComputer::update_qparam(
51 const std::unordered_map<const luci::CircleNode *, MinMaxVectors> *minmax_map)
53 if (minmax_map == nullptr)
54 throw std::invalid_argument("minmax_map is nullptr");
56 for (auto iter = minmax_map->begin(); iter != minmax_map->end(); ++iter)
58 auto node = iter->first;
59 auto minmax = iter->second;
61 auto min = getMovingAverage(minmax.min_vector, 1 - _update_const, _batch_size, true);
62 auto max = getMovingAverage(minmax.max_vector, 1 - _update_const, _batch_size, false);
64 auto quantparam = std::make_unique<luci::CircleQuantParam>();
65 quantparam->min.push_back(min);
66 quantparam->max.push_back(max);
68 assert(node->quantparam() == nullptr);
70 auto mutable_node = const_cast<luci::CircleNode *>(node);
71 mutable_node->quantparam(std::move(quantparam));
75 std::unique_ptr<MinMaxComputer> make_percentile_computer(float min_percentile, float max_percentile)
77 return std::make_unique<PercentileComputer>(min_percentile, max_percentile);
80 std::unique_ptr<MinMaxComputer> make_moving_avg_computer(uint32_t batch_size,
81 float moving_avg_const)
83 return std::make_unique<MovingAvgComputer>(batch_size, moving_avg_const);
86 } // namespace record_minmax