Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "RecordMinMax.h"
18
19 #include <arser/arser.h>
20 #include <vconone/vconone.h>
21
22 void print_version(void)
23 {
24   std::cout << "record-minmax version " << vconone::get_string() << std::endl;
25   std::cout << vconone::get_copyright() << std::endl;
26 }
27
28 int entry(const int argc, char **argv)
29 {
30   using namespace record_minmax;
31
32   arser::Arser arser(
33       "Embedding min/max values of activations to the circle model for post-training quantization");
34
35   arser.add_argument("--version")
36       .nargs(0)
37       .required(false)
38       .default_value(false)
39       .help("Show version information and exit")
40       .exit_with(print_version);
41
42   arser.add_argument("--input_model")
43       .nargs(1)
44       .type(arser::DataType::STR)
45       .required(true)
46       .help("Input model filepath");
47
48   arser.add_argument("--input_data")
49       .nargs(1)
50       .type(arser::DataType::STR)
51       .required(true)
52       .help("Input data filepath");
53
54   arser.add_argument("--output_model")
55       .nargs(1)
56       .type(arser::DataType::STR)
57       .required(true)
58       .help("Output model filepath");
59
60   arser.add_argument("--min_percentile")
61       .nargs(1)
62       .type(arser::DataType::FLOAT)
63       .help("Record n'th percentile of min");
64
65   arser.add_argument("--max_percentile")
66       .nargs(1)
67       .type(arser::DataType::FLOAT)
68       .help("Record n'th percentile of max");
69
70   arser.add_argument("--mode")
71       .nargs(1)
72       .type(arser::DataType::STR)
73       .help("Record mode. percentile (default) or moving_average");
74
75   try
76   {
77     arser.parse(argc, argv);
78   }
79   catch (const std::runtime_error &err)
80   {
81     std::cout << err.what() << std::endl;
82     std::cout << arser;
83     return 255;
84   }
85
86   auto input_model_path = arser.get<std::string>("--input_model");
87   auto input_data_path = arser.get<std::string>("--input_data");
88   auto output_model_path = arser.get<std::string>("--output_model");
89
90   // Default values
91   std::string mode("percentile");
92   float min_percentile = 1.0;
93   float max_percentile = 99.0;
94
95   if (arser["--min_percentile"])
96     min_percentile = arser.get<float>("--min_percentile");
97
98   if (arser["--max_percentile"])
99     max_percentile = arser.get<float>("--max_percentile");
100
101   if (arser["--mode"])
102     mode = arser.get<std::string>("--mode");
103
104   if (mode != "percentile" && mode != "moving_average")
105     throw std::runtime_error("Unsupported mode");
106
107   RecordMinMax rmm;
108
109   // Initialize interpreter and observer
110   rmm.initialize(input_model_path);
111
112   // Profile min/max while executing the given input data
113   rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
114
115   // Save profiled values to the model
116   rmm.saveModel(output_model_path);
117
118   return EXIT_SUCCESS;
119 }