3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
22 import tflite.SubGraph
25 from operator_parser import OperatorParser
26 from subgraph_printer import SubgraphPrinter
27 from model_saver import ModelSaver
class TFLiteModelFileParser(object):
    """Parse a TFLite model file and print (or save) information per subgraph."""

    def __init__(self, args):
        """Store the options needed by main().

        Args:
            args: parsed command-line namespace. Reads ``input_file`` (binary
                file object of the model), ``verbose`` (int print level),
                ``tensor`` / ``operator`` (lists of IDs, or None), ``config``
                (bool) and ``prefix`` (file prefix used when saving).
        """
        # Read flatbuffer file descriptor using argument
        self.tflite_file = args.input_file

        # Set print level, clamped to the supported range (0 ~ 1)
        self.print_level = max(0, min(args.verbose, 1))

        # Tensor / operator indices to print. A missing (None) or empty list
        # means "print everything".
        self.print_tensor_index = self._to_index_list(args.tensor)
        self.print_all_tensor = not self.print_tensor_index

        self.print_operator_index = self._to_index_list(args.operator)
        self.print_all_operator = not self.print_operator_index

        # Save option (-c/--config): when set, each subgraph is saved via
        # ModelSaver instead of being printed. Both flags are always defined
        # so SaveModel never reads an unset attribute.
        self.save_config = bool(args.config)
        self.save = self.save_config
        self.save_prefix = args.prefix

    @staticmethod
    def _to_index_list(values):
        """Convert an optional sequence of IDs (str or int) to a list of ints."""
        return [int(value) for value in (values or [])]

    def PrintModel(self, model_name, op_parser):
        """Print information about one parsed subgraph.

        Output is limited to the requested tensors/operators when specific
        indices were given.
        """
        printer = SubgraphPrinter(self.print_level, op_parser, model_name)

        if not self.print_all_tensor:
            printer.SetPrintSpecificTensors(self.print_tensor_index)

        if not self.print_all_operator:
            printer.SetPrintSpecificOperators(self.print_operator_index)

        # NOTE(review): assumes SubgraphPrinter.PrintInfo() emits the output
        # configured above - confirm against subgraph_printer.
        printer.PrintInfo()

    def SaveModel(self, model_name, op_parser):
        """Save per-operator configuration of one parsed subgraph (with -c)."""
        saver = ModelSaver(model_name, op_parser)

        if self.save_config:
            saver.SaveConfigInfo(self.save_prefix)

    def main(self):
        """Read the model file, print or save every subgraph, then print stats."""
        # Generate Model: top structure of tflite model file
        buf = self.tflite_file.read()
        tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
        num_subgraphs = tf_model.SubgraphsLength()

        stats = graph_stats.GraphStats()
        # Model file can have many models (subgraphs)
        for subgraph_index in range(num_subgraphs):
            tf_subgraph = tf_model.Subgraphs(subgraph_index)
            model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
            # 0th subgraph is main subgraph
            if subgraph_index == 0:
                model_name += " (MAIN)"

            # Parse operators
            op_parser = OperatorParser(tf_model, tf_subgraph)
            # NOTE(review): assumes OperatorParser exposes Parse() and does not
            # parse in its constructor - confirm against operator_parser.
            op_parser.Parse()

            stats += graph_stats.CalcGraphStats(op_parser)

            if self.save:
                # save all of operators in this model
                self.SaveModel(model_name, op_parser)
            else:
                # print all of operators or requested objects
                self.PrintModel(model_name, op_parser)

        print('==== Model Stats ({} Subgraphs) ===='.format(num_subgraphs))
        graph_stats.PrintGraphStats(stats, self.print_level)
if __name__ == '__main__':
    # Define arguments and read them
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "input_file", type=argparse.FileType('rb'), help="tflite file to read")
    arg_parser.add_argument(
        '-v', '--verbose', type=int, default=1, help="set print level (0~1, default: 1)")
    # type=int makes argparse reject non-integer IDs with a usage error
    # instead of a ValueError traceback later in TFLiteModelFileParser.
    arg_parser.add_argument(
        '-t',
        '--tensor',
        nargs='*',
        type=int,
        help="tensor ID to print information (default: all)")
    # NOTE(review): option strings below follow args.operator / args.config
    # usage and the "-c/--config" mention in the -p help; confirm they match
    # the published CLI.
    arg_parser.add_argument(
        '-o',
        '--operator',
        nargs='*',
        type=int,
        help="operator ID to print information (default: all)")
    arg_parser.add_argument(
        '-c',
        '--config',
        action='store_true',
        help="Save the configuration file per operator")
    arg_parser.add_argument(
        '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
    args = arg_parser.parse_args()

    # Run the parser; 'with' closes the input file opened by argparse.FileType.
    with args.input_file:
        TFLiteModelFileParser(args).main()