Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / tfl-inspect / driver / Driver.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Dump.h"
18
19 #include <arser/arser.h>
20 #include <foder/FileLoader.h>
21
22 #include <functional>
23 #include <iostream>
24 #include <map>
25 #include <memory>
26 #include <vector>
27 #include <string>
28
29 int entry(int argc, char **argv)
30 {
31   arser::Arser arser{"tfl-inspect allows users to retrieve various information from a TensorFlow "
32                      "Lite model files"};
33   arser.add_argument("--operators").nargs(0).help("Dump operators in tflite file");
34   arser.add_argument("--conv2d_weight")
35       .nargs(0)
36       .help("Dump Conv2D series weight operators in tflite file");
37   arser.add_argument("--op_version").nargs(0).help("Dump versions of the operators in tflite file");
38   arser.add_argument("tflite").type(arser::DataType::STR).help("TFLite file to inspect");
39
40   try
41   {
42     arser.parse(argc, argv);
43   }
44   catch (const std::runtime_error &err)
45   {
46     std::cout << err.what() << std::endl;
47     std::cout << arser;
48     return 255;
49   }
50
51   if (!arser["--operators"] && !arser["--conv2d_weight"] && !arser["--op_version"])
52   {
53     std::cout << "At least one option must be specified" << std::endl;
54     std::cout << arser;
55     return 255;
56   }
57
58   std::vector<std::unique_ptr<tflinspect::DumpInterface>> dumps;
59
60   if (arser["--operators"])
61     dumps.push_back(std::make_unique<tflinspect::DumpOperators>());
62   if (arser["--conv2d_weight"])
63     dumps.push_back(std::make_unique<tflinspect::DumpConv2DWeight>());
64   if (arser["--op_version"])
65     dumps.push_back(std::make_unique<tflinspect::DumpOperatorVersion>());
66
67   std::string model_file = arser.get<std::string>("tflite");
68
69   // Load TF lite model from a tflite file
70   foder::FileLoader fileLoader{model_file};
71   std::vector<char> modelData = fileLoader.load();
72   const tflite::Model *tfliteModel = tflite::GetModel(modelData.data());
73   if (tfliteModel == nullptr)
74   {
75     std::cerr << "ERROR: Failed to load tflite '" << model_file << "'" << std::endl;
76     return 255;
77   }
78
79   for (auto &dump : dumps)
80   {
81     dump->run(std::cout, tfliteModel);
82   }
83
84   return 0;
85 }