4ef2374cf592d8b6131e57ff589975f96fed0146
[platform/core/ml/nnfw.git] / tools / tflitefile_tool / model_parser.py
1 #!/usr/bin/python
2
3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import os
18 import sys
19 import numpy
20
21 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tflite'))
22 flatbuffersPath = '../../externals/flatbuffers'
23 sys.path.append(
24     os.path.join(os.path.dirname(os.path.abspath(__file__)), flatbuffersPath + '/python'))
25
26 import flatbuffers
27 import tflite.Model
28 import tflite.SubGraph
29 import argparse
30 import graph_stats
31 from operator_parser import OperatorParser
32 from subgraph_printer import SubgraphPrinter
33 from model_saver import ModelSaver
34
35
36 class TFLiteModelFileParser(object):
37     def __init__(self, args):
38         # Read flatbuffer file descriptor using argument
39         self.tflite_file = args.input_file
40
41         # Set print level (0 ~ 1)
42         self.print_level = args.verbose
43         if (args.verbose > 1):
44             self.print_level = 1
45         if (args.verbose < 0):
46             self.print_level = 0
47
48         # Set tensor index list to print information
49         self.print_all_tensor = True
50         if (args.tensor != None):
51             if (len(args.tensor) != 0):
52                 self.print_all_tensor = False
53                 self.print_tensor_index = []
54                 for tensor_index in args.tensor:
55                     self.print_tensor_index.append(int(tensor_index))
56
57         # Set operator index list to print information
58         self.print_all_operator = True
59         if (args.operator != None):
60             if (len(args.operator) != 0):
61                 self.print_all_operator = False
62                 self.print_operator_index = []
63                 for operator_index in args.operator:
64                     self.print_operator_index.append(int(operator_index))
65
66         # Set config option
67         self.save = False
68         if args.config:
69             self.save = True
70             self.save_config = True
71
72         if self.save == True:
73             self.save_prefix = args.prefix
74
75     def PrintModel(self, model_name, op_parser):
76         printer = SubgraphPrinter(self.print_level, op_parser, model_name)
77
78         if self.print_all_tensor == False:
79             printer.SetPrintSpecificTensors(self.print_tensor_index)
80
81         if self.print_all_operator == False:
82             printer.SetPrintSpecificOperators(self.print_operator_index)
83
84         printer.PrintInfo()
85
86     def SaveModel(self, model_name, op_parser):
87         saver = ModelSaver(model_name, op_parser)
88
89         if self.save_config == True:
90             saver.SaveConfigInfo(self.save_prefix)
91
92     def main(self):
93         # Generate Model: top structure of tflite model file
94         buf = self.tflite_file.read()
95         buf = bytearray(buf)
96         tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
97
98         stats = graph_stats.GraphStats()
99         # Model file can have many models
100         for subgraph_index in range(tf_model.SubgraphsLength()):
101             tf_subgraph = tf_model.Subgraphs(subgraph_index)
102             model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
103             # 0th subgraph is main subgraph
104             if (subgraph_index == 0):
105                 model_name += " (MAIN)"
106
107             # Parse Operators
108             op_parser = OperatorParser(tf_model, tf_subgraph)
109             op_parser.Parse()
110
111             stats += graph_stats.CalcGraphStats(op_parser)
112
113             if self.save == False:
114                 # print all of operators or requested objects
115                 self.PrintModel(model_name, op_parser)
116             else:
117                 # save all of operators in this model
118                 self.SaveModel(model_name, op_parser)
119
120         print('==== Model Stats ({} Subgraphs) ===='.format(tf_model.SubgraphsLength()))
121         print('')
122         graph_stats.PrintGraphStats(stats, self.print_level)
123
124
125 if __name__ == '__main__':
126     # Define argument and read
127     arg_parser = argparse.ArgumentParser()
128     arg_parser.add_argument(
129         "input_file", type=argparse.FileType('rb'), help="tflite file to read")
130     arg_parser.add_argument(
131         '-v', '--verbose', type=int, default=1, help="set print level (0~1, default: 1)")
132     arg_parser.add_argument(
133         '-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)")
134     arg_parser.add_argument(
135         '-o',
136         '--operator',
137         nargs='*',
138         help="operator ID to print information (default: all)")
139     arg_parser.add_argument(
140         '-c',
141         '--config',
142         action='store_true',
143         help="Save the configuration file per operator")
144     arg_parser.add_argument(
145         '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
146     args = arg_parser.parse_args()
147
148     # Call main function
149     TFLiteModelFileParser(args).main()