Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-opselector / driver / Driver.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ModuleIO.h"
18 #include "OpSelector.h"
19
20 #include <luci/ConnectNode.h>
21 #include <luci/Profile/CircleNodeID.h>
22 #include <luci/Service/CircleNodeClone.h>
23
24 #include <arser/arser.h>
25 #include <vconone/vconone.h>
26
27 #include <iostream>
28 #include <string>
29 #include <vector>
30 #include <algorithm>
31 #include <cctype>
32 #include <numeric>
33 #include <sstream>
34
35 void print_version(void)
36 {
37   std::cout << "circle-opselector version " << vconone::get_string() << std::endl;
38   std::cout << vconone::get_copyright() << std::endl;
39 }
40
41 int entry(int argc, char **argv)
42 {
43   // TODO Add new option names!
44
45   arser::Arser arser("circle-opselector provides selecting operations in circle model");
46
47   arser::Helper::add_version(arser, print_version);
48
49   // TODO Add new options!
50
51   arser.add_argument("input").help("Input circle model");
52   arser.add_argument("output").help("Output circle model");
53
54   // select option
55   arser.add_argument("--by_id").help("Input operation id to select nodes.");
56   arser.add_argument("--by_name").help("Input operation name to select nodes.");
57
58   try
59   {
60     arser.parse(argc, argv);
61   }
62   catch (const std::runtime_error &err)
63   {
64     std::cerr << err.what() << std::endl;
65     std::cout << arser;
66     return EXIT_FAILURE;
67   }
68
69   std::string input_path = arser.get<std::string>("input");
70   std::string output_path = arser.get<std::string>("output");
71
72   if (!arser["--by_id"] && !arser["--by_name"] || arser["--by_id"] && arser["--by_name"])
73   {
74     std::cerr << "ERROR: Either option '--by_id' or '--by_name' must be specified" << std::endl;
75     std::cerr << arser;
76     return EXIT_FAILURE;
77   }
78
79   // Import original circle file.
80   auto module = opselector::getModule(input_path);
81
82   // TODO support two or more subgraphs
83   if (module.get()->size() != 1)
84   {
85     std::cerr << "ERROR: Not support two or more subgraphs" << std::endl;
86     return EXIT_FAILURE;
87   }
88
89   opselector::OpSelector op_selector{module.get()};
90
91   std::unique_ptr<luci::Module> new_module;
92   std::string operator_input;
93
94   if (arser["--by_id"])
95   {
96     operator_input = arser.get<std::string>("--by_id");
97     new_module = op_selector.select_by<opselector::SelectType::ID>(operator_input);
98   }
99   if (arser["--by_name"])
100   {
101     operator_input = arser.get<std::string>("--by_name");
102     new_module = op_selector.select_by<opselector::SelectType::NAME>(operator_input);
103   }
104
105   if (not opselector::exportModule(new_module.get(), output_path))
106   {
107     std::cerr << "ERROR: Cannot export the module" << std::endl;
108     return EXIT_FAILURE;
109   }
110
111   return 0;
112 }