Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / tflite2circle / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <iostream>
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include <arser/arser.h>
23
24 #include "CircleModel.h"
25 #include "TFLModel.h"
26
27 #include <vconone/vconone.h>
28
29 void print_version(void)
30 {
31   std::cout << "tflite2circle version " << vconone::get_string() << std::endl;
32   std::cout << vconone::get_copyright() << std::endl;
33 }
34
35 int entry(int argc, char **argv)
36 {
37   arser::Arser arser{"tflite2circle is a Tensorflow lite to circle model converter"};
38
39   arser.add_argument("--version")
40       .nargs(0)
41       .required(false)
42       .default_value(false)
43       .help("Show version information and exit")
44       .exit_with(print_version);
45
46   arser.add_argument("tflite")
47       .nargs(1)
48       .type(arser::DataType::STR)
49       .help("Source tflite file path to convert");
50   arser.add_argument("circle").nargs(1).type(arser::DataType::STR).help("Target circle file path");
51
52   try
53   {
54     arser.parse(argc, argv);
55   }
56   catch (const std::runtime_error &err)
57   {
58     std::cout << err.what() << std::endl;
59     std::cout << arser;
60     return 255;
61   }
62
63   std::string tfl_path = arser.get<std::string>("tflite");
64   std::string circle_path = arser.get<std::string>("circle");
65   // read tflite file
66   tflite2circle::TFLModel tfl_model(tfl_path);
67   if (!tfl_model.is_valid())
68   {
69     std::cerr << "ERROR: Failed to load tflite '" << tfl_path << "'" << std::endl;
70     return 255;
71   }
72
73   // create flatbuffer builder
74   auto flatbuffer_builder = std::make_unique<flatbuffers::FlatBufferBuilder>(1024);
75
76   // convert tflite to circle
77   tflite2circle::CircleModel circle_model{flatbuffer_builder, tfl_model};
78
79   std::ofstream outfile{circle_path, std::ios::binary};
80
81   outfile.write(circle_model.base(), circle_model.size());
82   outfile.close();
83   // TODO find a better way of error handling
84   if (outfile.fail())
85   {
86     std::cerr << "ERROR: Failed to write circle '" << circle_path << "'" << std::endl;
87     return 255;
88   }
89
90   return 0;
91 }