Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / runtime / onert / odc / Quantizer.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Quantizer.h"
18
19 #include <luci/ImporterEx.h>
20 #include <luci/CircleQuantizer.h>
21 #include <luci/CircleExporter.h>
22 #include <luci/CircleFileExpContract.h>
23
24 #include <iostream>
25
26 extern "C" onert::odc::IQuantizer *create_quantizer() { return new onert::odc::Quantizer(); }
27 extern "C" void destroy_quantizer(onert::odc::IQuantizer *quantizer) { delete quantizer; }
28
29 namespace onert
30 {
31 namespace odc
32 {
33
34 int Quantizer::quantize(const char *in, const char *out, bool is_q16)
35 {
36   // Load model from the file
37   luci::ImporterEx importerex;
38   auto module = importerex.importVerifyModule(std::string(in));
39   if (module.get() == nullptr)
40     return 1;
41
42   luci::CircleQuantizer quantizer;
43   auto options = quantizer.options();
44   {
45     options->enable(luci::CircleQuantizer::Options::Algorithm::QuantizeWeights);
46
47     using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
48     options->param(AlgorithmParameters::Quantize_input_model_dtype, "float32");
49     options->param(AlgorithmParameters::Quantize_output_model_dtype, is_q16 ? "int16" : "int8");
50     options->param(AlgorithmParameters::Quantize_granularity, "channel");
51   }
52
53   for (size_t idx = 0; idx < module->size(); ++idx)
54   {
55     auto graph = module->graph(idx);
56
57     // quantize the graph
58     quantizer.quantize(graph);
59
60     // Skip validate
61     // TODO Validate if needed
62 #if 0
63     if (!luci::validate(graph))
64     {
65       std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
66       return 1;
67     }
68 #endif
69   }
70
71   // Export to output Circle file
72   luci::CircleExporter exporter;
73   luci::CircleFileExpContract contract(module.get(), std::string(out));
74
75   if (!exporter.invoke(&contract))
76     return 1;
77
78   // Return 0 when luci::CircleQuantizer::Options::Algorithm::QuantizeWeights is ready
79   return 0;
80 }
81
82 } // namespace odc
83 } // namespace onert