Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / circle2circle / src / Circle2Circle.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <foder/FileLoader.h>
18
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
25
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
29
30 #include <functional>
31 #include <iostream>
32 #include <string>
33
34 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
35 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
36
37 void print_version(void)
38 {
39   std::cout << "circle2circle version " << vconone::get_string() << std::endl;
40   std::cout << vconone::get_copyright() << std::endl;
41 }
42
43 int entry(int argc, char **argv)
44 {
45   // Simple argument parser (based on map)
46   luci::CircleOptimizer optimizer;
47
48   auto options = optimizer.options();
49   auto settings = luci::UserSettings::settings();
50
51   arser::Arser arser("circle2circle provides circle model optimization and transformations");
52
53   arser.add_argument("--version")
54       .nargs(0)
55       .required(false)
56       .default_value(false)
57       .help("Show version information and exit")
58       .exit_with(print_version);
59
60   arser.add_argument("--all").nargs(0).required(false).default_value(false).help(
61       "Enable all optimize options");
62
63   arser.add_argument("--fuse_batchnorm_with_tconv")
64       .nargs(0)
65       .required(false)
66       .default_value(false)
67       .help("This will fuse BatchNorm operators to Transposed Convolution operator");
68
69   arser.add_argument("--fuse_bcq")
70       .nargs(0)
71       .required(false)
72       .default_value(false)
73       .help("This will fuse operators and apply Binary Coded Quantization");
74
75   arser.add_argument("--fuse_instnorm")
76       .nargs(0)
77       .required(false)
78       .default_value(false)
79       .help("This will fuse operators to InstanceNorm operator");
80
81   arser.add_argument("--resolve_customop_add")
82       .nargs(0)
83       .required(false)
84       .default_value(false)
85       .help("This will convert Custom(Add) to Add operator");
86
87   arser.add_argument("--resolve_customop_batchmatmul")
88       .nargs(0)
89       .required(false)
90       .default_value(false)
91       .help("This will convert Custom(BatchMatmul) to BatchMatmul operator");
92
93   arser.add_argument("--resolve_customop_matmul")
94       .nargs(0)
95       .required(false)
96       .default_value(false)
97       .help("This will convert Custom(Matmul) to Matmul operator");
98
99   arser.add_argument("--mute_warnings")
100       .nargs(0)
101       .required(false)
102       .default_value(false)
103       .help("This will turn off warning messages");
104
105   arser.add_argument("--disable_validation")
106       .nargs(0)
107       .required(false)
108       .default_value(false)
109       .help("This will turn off operator validations. May help input model investigation.");
110
111   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
112   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
113
114   try
115   {
116     arser.parse(argc, argv);
117   }
118   catch (const std::runtime_error &err)
119   {
120     std::cout << err.what() << std::endl;
121     std::cout << arser;
122     return 255;
123   }
124
125   if (arser.get<bool>("--all"))
126   {
127     options->enable(Algorithms::FuseBCQ);
128     options->enable(Algorithms::FuseInstanceNorm);
129     options->enable(Algorithms::ResolveCustomOpAdd);
130     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
131     options->enable(Algorithms::ResolveCustomOpMatMul);
132   }
133   if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
134     options->enable(Algorithms::FuseBatchNormWithTConv);
135   if (arser.get<bool>("--fuse_bcq"))
136     options->enable(Algorithms::FuseBCQ);
137   if (arser.get<bool>("--fuse_instnorm"))
138     options->enable(Algorithms::FuseInstanceNorm);
139   if (arser.get<bool>("--resolve_customop_add"))
140     options->enable(Algorithms::ResolveCustomOpAdd);
141   if (arser.get<bool>("--resolve_customop_batchmatmul"))
142     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
143   if (arser.get<bool>("--resolve_customop_matmul"))
144     options->enable(Algorithms::ResolveCustomOpMatMul);
145
146   if (arser.get<bool>("--mute_warnings"))
147     settings->set(luci::UserSettings::Key::MuteWarnings, true);
148   if (arser.get<bool>("--disable_validation"))
149     settings->set(luci::UserSettings::Key::DisableValidation, true);
150
151   std::string input_path = arser.get<std::string>("input");
152   std::string output_path = arser.get<std::string>("output");
153
154   // Load model from the file
155   foder::FileLoader file_loader{input_path};
156   std::vector<char> model_data;
157
158   try
159   {
160     model_data = file_loader.load();
161   }
162   catch (const std::runtime_error &err)
163   {
164     std::cerr << err.what() << std::endl;
165     return EXIT_FAILURE;
166   }
167
168   flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
169   if (!circle::VerifyModelBuffer(verifier))
170   {
171     std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
172     return EXIT_FAILURE;
173   }
174
175   const circle::Model *circle_model = circle::GetModel(model_data.data());
176   if (circle_model == nullptr)
177   {
178     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
179     return EXIT_FAILURE;
180   }
181
182   // Import from input Circle file
183   luci::Importer importer;
184   auto module = importer.importModule(circle_model);
185
186   for (size_t idx = 0; idx < module->size(); ++idx)
187   {
188     auto graph = module->graph(idx);
189
190     // call luci optimizations
191     optimizer.optimize(graph);
192
193     if (!luci::validate(graph))
194     {
195       if (settings->get(luci::UserSettings::Key::DisableValidation))
196         std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
197       else
198       {
199         std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
200         return 255;
201       }
202     }
203   }
204
205   // Export to output Circle file
206   luci::CircleExporter exporter;
207
208   luci::CircleFileExpContract contract(module.get(), output_path);
209
210   if (!exporter.invoke(&contract))
211   {
212     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
213     return 255;
214   }
215
216   return 0;
217 }