Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "RecordMinMax.h"
18
19 #include <arser/arser.h>
20
21 int entry(const int argc, char **argv)
22 {
23   using namespace record_minmax;
24
25   arser::Arser arser(
26       "Embedding min/max values of activations to the circle model for post-training quantization");
27
28   arser.add_argument("--input_model")
29       .nargs(1)
30       .type(arser::DataType::STR)
31       .required(true)
32       .help("Input model filepath");
33
34   arser.add_argument("--input_data")
35       .nargs(1)
36       .type(arser::DataType::STR)
37       .required(true)
38       .help("Input data filepath");
39
40   arser.add_argument("--output_model")
41       .nargs(1)
42       .type(arser::DataType::STR)
43       .required(true)
44       .help("Output model filepath");
45
46   arser.add_argument("--min_percentile")
47       .nargs(1)
48       .type(arser::DataType::FLOAT)
49       .help("Record n'th percentile of min");
50
51   arser.add_argument("--max_percentile")
52       .nargs(1)
53       .type(arser::DataType::FLOAT)
54       .help("Record n'th percentile of max");
55
56   arser.add_argument("--mode")
57       .nargs(1)
58       .type(arser::DataType::STR)
59       .help("Record mode. percentile (default) or moving_average");
60
61   try
62   {
63     arser.parse(argc, argv);
64   }
65   catch (const std::runtime_error &err)
66   {
67     std::cout << err.what() << std::endl;
68     std::cout << arser;
69     return 0;
70   }
71
72   auto input_model_path = arser.get<std::string>("--input_model");
73   auto input_data_path = arser.get<std::string>("--input_data");
74   auto output_model_path = arser.get<std::string>("--output_model");
75
76   // Default values
77   std::string mode("percentile");
78   float min_percentile = 1.0;
79   float max_percentile = 99.0;
80
81   if (arser["--min_percentile"])
82     min_percentile = arser.get<float>("--min_percentile");
83
84   if (arser["--max_percentile"])
85     max_percentile = arser.get<float>("--max_percentile");
86
87   if (arser["--mode"])
88     mode = arser.get<std::string>("--mode");
89
90   if (mode != "percentile" && mode != "moving_average")
91     throw std::runtime_error("Unsupported mode");
92
93   RecordMinMax rmm;
94
95   // Initialize interpreter and observer
96   rmm.initialize(input_model_path);
97
98   // Profile min/max while executing the given input data
99   rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
100
101   // Save profiled values to the model
102   rmm.saveModel(output_model_path);
103
104   return EXIT_SUCCESS;
105 }