Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / tflite2circle / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <iostream>
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include <arser/arser.h>
23
24 #include "CircleModel.h"
25 #include "TFLModel.h"
26
27 #include <vconone/vconone.h>
28
29 void print_version(void)
30 {
31   std::cout << "tflite2circle version " << vconone::get_string() << std::endl;
32   std::cout << vconone::get_copyright() << std::endl;
33 }
34
35 int entry(int argc, char **argv)
36 {
37   arser::Arser arser{"tflite2circle is a Tensorflow lite to circle model converter"};
38
39   arser.add_argument("--version")
40     .nargs(0)
41     .required(false)
42     .default_value(false)
43     .help("Show version information and exit")
44     .exit_with(print_version);
45
46   arser.add_argument("-V", "--verbose")
47     .nargs(0)
48     .required(false)
49     .default_value(false)
50     .help("output additional information to stdout or stderr");
51
52   arser.add_argument("tflite")
53     .nargs(1)
54     .type(arser::DataType::STR)
55     .help("Source tflite file path to convert");
56   arser.add_argument("circle").nargs(1).type(arser::DataType::STR).help("Target circle file path");
57
58   try
59   {
60     arser.parse(argc, argv);
61   }
62   catch (const std::runtime_error &err)
63   {
64     std::cerr << err.what() << std::endl;
65     std::cout << arser;
66     return 255;
67   }
68
69   std::string tfl_path = arser.get<std::string>("tflite");
70   std::string circle_path = arser.get<std::string>("circle");
71   // read tflite file
72   tflite2circle::TFLModel tfl_model(tfl_path);
73   if (not tfl_model.verify_data())
74   {
75     std::cerr << "ERROR: Failed to verify tflite '" << tfl_path << "'" << std::endl;
76     return 255;
77   }
78
79   // create flatbuffer builder
80   auto flatbuffer_builder = std::make_unique<flatbuffers::FlatBufferBuilder>(1024);
81
82   // convert tflite to circle
83   tflite2circle::CircleModel circle_model{flatbuffer_builder, tfl_model.get_model()};
84
85   std::ofstream outfile{circle_path, std::ios::binary};
86
87   outfile.write(circle_model.base(), circle_model.size());
88   outfile.close();
89   // TODO find a better way of error handling
90   if (outfile.fail())
91   {
92     std::cerr << "ERROR: Failed to write circle '" << circle_path << "'" << std::endl;
93     return 255;
94   }
95
96   return 0;
97 }