849597b4699869106f2e926b3b0582f262850aa9
[platform/core/ml/nnfw.git] / compiler / circle2circle / src / Circle2Circle.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CircleExpContract.h"
18
19 #include <foder/FileLoader.h>
20
21 #include <luci/Importer.h>
22 #include <luci/CircleOptimizer.h>
23 #include <luci/Service/Validate.h>
24 #include <luci/CircleExporter.h>
25 #include <luci/UserSettings.h>
26
27 #include <oops/InternalExn.h>
28 #include <arser/arser.h>
29 #include <vconone/vconone.h>
30
31 #include <functional>
32 #include <iostream>
33 #include <string>
34
35 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
36 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
37
38 void print_version(void)
39 {
40   std::cout << "circle2circle version " << vconone::get_string() << std::endl;
41   std::cout << vconone::get_copyright() << std::endl;
42 }
43
44 int entry(int argc, char **argv)
45 {
46   // Simple argument parser (based on map)
47   luci::CircleOptimizer optimizer;
48
49   auto options = optimizer.options();
50   auto settings = luci::UserSettings::settings();
51
52   arser::Arser arser("circle2circle provides circle model optimization and transformations");
53
54   arser.add_argument("--version")
55       .nargs(0)
56       .required(false)
57       .default_value(false)
58       .help("Show version information and exit")
59       .exit_with(print_version);
60
61   arser.add_argument("--all").nargs(0).required(false).default_value(false).help(
62       "Enable all optimize options");
63
64   arser.add_argument("--fuse_bcq")
65       .nargs(0)
66       .required(false)
67       .default_value(false)
68       .help("This will fuse operators and apply Binary Coded Quantization");
69
70   arser.add_argument("--fuse_instnorm")
71       .nargs(0)
72       .required(false)
73       .default_value(false)
74       .help("This will fuse operators to InstanceNorm operator");
75
76   arser.add_argument("--resolve_customop_add")
77       .nargs(0)
78       .required(false)
79       .default_value(false)
80       .help("This will convert Custom(Add) to Add operator");
81
82   arser.add_argument("--resolve_customop_batchmatmul")
83       .nargs(0)
84       .required(false)
85       .default_value(false)
86       .help("This will convert Custom(BatchMatmul) to BatchMatmul operator");
87
88   arser.add_argument("--resolve_customop_matmul")
89       .nargs(0)
90       .required(false)
91       .default_value(false)
92       .help("This will convert Custom(Matmul) to Matmul operator");
93
94   arser.add_argument("--mute_warnings")
95       .nargs(0)
96       .required(false)
97       .default_value(false)
98       .help("This will turn off warning messages");
99
100   arser.add_argument("--disable_validation")
101       .nargs(0)
102       .required(false)
103       .default_value(false)
104       .help("This will turn off operator vaidations. May help input model investigation.");
105
106   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
107   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
108
109   try
110   {
111     arser.parse(argc, argv);
112   }
113   catch (const std::runtime_error &err)
114   {
115     std::cout << err.what() << std::endl;
116     std::cout << arser;
117     return 255;
118   }
119
120   if (arser.get<bool>("--all"))
121   {
122     options->enable(Algorithms::FuseBCQ);
123     options->enable(Algorithms::FuseInstanceNorm);
124     options->enable(Algorithms::ResolveCustomOpAdd);
125     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
126     options->enable(Algorithms::ResolveCustomOpMatMul);
127   }
128   if (arser.get<bool>("--fuse_bcq"))
129     options->enable(Algorithms::FuseBCQ);
130   if (arser.get<bool>("--fuse_instnorm"))
131     options->enable(Algorithms::FuseInstanceNorm);
132   if (arser.get<bool>("--resolve_customop_add"))
133     options->enable(Algorithms::ResolveCustomOpAdd);
134   if (arser.get<bool>("--resolve_customop_batchmatmul"))
135     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
136   if (arser.get<bool>("--resolve_customop_matmul"))
137     options->enable(Algorithms::ResolveCustomOpMatMul);
138
139   if (arser.get<bool>("--mute_warnings"))
140     settings->set(luci::UserSettings::Key::MuteWarnings, true);
141   if (arser.get<bool>("--disable_validation"))
142     settings->set(luci::UserSettings::Key::DisableValidation, true);
143
144   std::string input_path = arser.get<std::string>("input");
145   std::string output_path = arser.get<std::string>("output");
146
147   // Load model from the file
148   foder::FileLoader file_loader{input_path};
149   std::vector<char> model_data;
150
151   try
152   {
153     model_data = file_loader.load();
154   }
155   catch (const std::runtime_error &err)
156   {
157     std::cerr << err.what() << std::endl;
158     return EXIT_FAILURE;
159   }
160   const circle::Model *circle_model = circle::GetModel(model_data.data());
161   if (circle_model == nullptr)
162   {
163     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
164     return EXIT_FAILURE;
165   }
166
167   // Import from input Circle file
168   luci::Importer importer;
169   auto module = importer.importModule(circle_model);
170
171   for (size_t idx = 0; idx < module->size(); ++idx)
172   {
173     auto graph = module->graph(idx);
174
175     // call luci optimizations
176     optimizer.optimize(graph);
177
178     if (!luci::validate(graph))
179     {
180       std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
181       return 255;
182     }
183   }
184
185   // Export to output Circle file
186   luci::CircleExporter exporter;
187
188   CircleExpContract contract(module.get(), output_path);
189
190   if (!exporter.invoke(&contract))
191   {
192     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
193     return 255;
194   }
195
196   return 0;
197 }