Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / QuantizeWeightsPass.cpp
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/QuantizeWeightsPass.h"
18 #include "QuantizeWeightsOnly.h"
19 #include "QuantizationUtils.h"
20
21 #include <luci/Log.h>
22
23 namespace luci
24 {
25
26 bool QuantizeWeightsPass::run(loco::Graph *g)
27 {
28   LOGGER(l);
29   INFO(l) << "QuantizeWeightsPass Start" << std::endl;
30
31   if (_ctx->input_model_dtype != loco::DataType::FLOAT32)
32     throw std::runtime_error("Weights-only quantization supports float32 input only");
33
34   // Quantize weights
35   for (auto node : loco::active_nodes(loco::output_nodes(g)))
36   {
37     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
38     QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity);
39     circle_node->accept(&qw);
40   }
41
42   INFO(l) << "QuantizeWeightsPass End" << std::endl;
43   return false; // one time run
44 }
45
46 } // namespace luci