Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / tflite2circle / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <iostream>
18 #include <memory>
19 #include <string>
20 #include <vector>
21
22 #include <arser/arser.h>
23
24 #include "CircleModel.h"
25 #include "TFLModel.h"
26
27 int entry(int argc, char **argv)
28 {
29   arser::Arser arser{"tflite2circle is a Tensorflow lite to circle model converter"};
30
31   arser.add_argument("tflite")
32       .nargs(1)
33       .type(arser::DataType::STR)
34       .help("Source tflite file path to convert");
35   arser.add_argument("circle").nargs(1).type(arser::DataType::STR).help("Target circle file path");
36
37   try
38   {
39     arser.parse(argc, argv);
40   }
41   catch (const std::runtime_error &err)
42   {
43     std::cout << err.what() << std::endl;
44     std::cout << arser;
45     return 0;
46   }
47
48   std::string tfl_path = arser.get<std::string>("tflite");
49   std::string circle_path = arser.get<std::string>("circle");
50   // read tflite file
51   tflite2circle::TFLModel tfl_model(tfl_path);
52   if (!tfl_model.is_valid())
53   {
54     std::cerr << "ERROR: Failed to load tflite '" << tfl_path << "'" << std::endl;
55     return 255;
56   }
57
58   // create flatbuffer builder
59   auto flatbuffer_builder = std::make_unique<flatbuffers::FlatBufferBuilder>(1024);
60
61   // convert tflite to circle
62   tflite2circle::CircleModel circle_model{flatbuffer_builder, tfl_model};
63
64   std::ofstream outfile{circle_path, std::ios::binary};
65
66   outfile.write(circle_model.base(), circle_model.size());
67   outfile.close();
68   // TODO find a better way of error handling
69   if (outfile.fail())
70   {
71     std::cerr << "ERROR: Failed to write circle '" << circle_path << "'" << std::endl;
72     return 255;
73   }
74
75   return 0;
76 }