cd66bf500b6ac9721e5ba3693be793fc339d1eb4
[platform/core/ml/nnfw.git] / tools / tflitefile_tool / model_parser.py
1 #!/usr/bin/python
2
3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import os
18 import sys
19 import numpy
20 import flatbuffers
21 import tflite.Model
22 import tflite.SubGraph
23 import argparse
24 import graph_stats
25 from operator_parser import OperatorParser
26 from subgraph_printer import SubgraphPrinter
27 from model_saver import ModelSaver
28
29
30 class TFLiteModelFileParser(object):
31     def __init__(self, args):
32         # Read flatbuffer file descriptor using argument
33         self.tflite_file = args.input_file
34
35         # Set print level (0 ~ 1)
36         self.print_level = args.verbose
37         if (args.verbose > 1):
38             self.print_level = 1
39         if (args.verbose < 0):
40             self.print_level = 0
41
42         # Set tensor index list to print information
43         self.print_all_tensor = True
44         if (args.tensor != None):
45             if (len(args.tensor) != 0):
46                 self.print_all_tensor = False
47                 self.print_tensor_index = []
48                 for tensor_index in args.tensor:
49                     self.print_tensor_index.append(int(tensor_index))
50
51         # Set operator index list to print information
52         self.print_all_operator = True
53         if (args.operator != None):
54             if (len(args.operator) != 0):
55                 self.print_all_operator = False
56                 self.print_operator_index = []
57                 for operator_index in args.operator:
58                     self.print_operator_index.append(int(operator_index))
59
60         # Set config option
61         self.save = False
62         if args.config:
63             self.save = True
64             self.save_config = True
65
66         if self.save == True:
67             self.save_prefix = args.prefix
68
69     def PrintModel(self, model_name, op_parser):
70         printer = SubgraphPrinter(self.print_level, op_parser, model_name)
71
72         if self.print_all_tensor == False:
73             printer.SetPrintSpecificTensors(self.print_tensor_index)
74
75         if self.print_all_operator == False:
76             printer.SetPrintSpecificOperators(self.print_operator_index)
77
78         printer.PrintInfo()
79
80     def SaveModel(self, model_name, op_parser):
81         saver = ModelSaver(model_name, op_parser)
82
83         if self.save_config == True:
84             saver.SaveConfigInfo(self.save_prefix)
85
86     def main(self):
87         # Generate Model: top structure of tflite model file
88         buf = self.tflite_file.read()
89         buf = bytearray(buf)
90         tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
91
92         stats = graph_stats.GraphStats()
93         # Model file can have many models
94         for subgraph_index in range(tf_model.SubgraphsLength()):
95             tf_subgraph = tf_model.Subgraphs(subgraph_index)
96             model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
97             # 0th subgraph is main subgraph
98             if (subgraph_index == 0):
99                 model_name += " (MAIN)"
100
101             # Parse Operators
102             op_parser = OperatorParser(tf_model, tf_subgraph)
103             op_parser.Parse()
104
105             stats += graph_stats.CalcGraphStats(op_parser)
106
107             if self.save == False:
108                 # print all of operators or requested objects
109                 self.PrintModel(model_name, op_parser)
110             else:
111                 # save all of operators in this model
112                 self.SaveModel(model_name, op_parser)
113
114         print('==== Model Stats ({} Subgraphs) ===='.format(tf_model.SubgraphsLength()))
115         print('')
116         graph_stats.PrintGraphStats(stats, self.print_level)
117
118
119 if __name__ == '__main__':
120     # Define argument and read
121     arg_parser = argparse.ArgumentParser()
122     arg_parser.add_argument(
123         "input_file", type=argparse.FileType('rb'), help="tflite file to read")
124     arg_parser.add_argument(
125         '-v', '--verbose', type=int, default=1, help="set print level (0~1, default: 1)")
126     arg_parser.add_argument(
127         '-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)")
128     arg_parser.add_argument(
129         '-o',
130         '--operator',
131         nargs='*',
132         help="operator ID to print information (default: all)")
133     arg_parser.add_argument(
134         '-c',
135         '--config',
136         action='store_true',
137         help="Save the configuration file per operator")
138     arg_parser.add_argument(
139         '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
140     args = arg_parser.parse_args()
141
142     # Call main function
143     TFLiteModelFileParser(args).main()