3 # Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
# Put the 'tflite' directory next to this script on sys.path so that
# `import tflite.*` below resolves, and locate the flatbuffers '/python'
# directory relative to this script as well.
21 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tflite'))
# Relative to this script's directory, not the current working directory.
22 flatbuffersPath = '../../externals/flatbuffers'
# NOTE(review): the line that opens the call taking this argument is not
# visible in this listing; presumably `sys.path.append(` as on the line
# above - confirm against the full file.
24 os.path.join(os.path.dirname(os.path.abspath(__file__)), flatbuffersPath + '/python'))
28 import tflite.SubGraph
31 from operator_parser import OperatorParser
32 from subgraph_printer import SubgraphPrinter
33 from model_saver import ModelSaver
# Command-line driver for a TFLite flatbuffer file. For each subgraph it
# either prints the subgraph (PrintModel) or saves per-operator config info
# (SaveModel), then prints statistics aggregated over all subgraphs.
# NOTE(review): this listing is missing several original lines (marked below);
# confirm against the full file before relying on these comments.
36 class TFLiteModelFileParser(object):
# Copy the parsed command-line options (the `args` built by
# arg_parser.parse_args() in the __main__ block) onto the instance.
37 def __init__(self, args):
38 # Read flatbuffer file descriptor using argument
# Already opened in binary mode by argparse.FileType('rb'); read in the method below.
39 self.tflite_file = args.input_file
41 # Set print level (0 ~ 1)
42 self.print_level = args.verbose
# NOTE(review): the bodies of the next two checks are not visible in this
# listing; per the "(0 ~ 1)" comment above they presumably clamp
# self.print_level into [0, 1] - confirm.
43 if (args.verbose > 1):
45 if (args.verbose < 0):
48 # Set tensor index list to print information
# An absent (None) or empty -t/--tensor list means "print all tensors".
# print_tensor_index is only created when specific indices were requested;
# PrintModel reads it only when print_all_tensor is False.
49 self.print_all_tensor = True
50 if (args.tensor != None):
51 if (len(args.tensor) != 0):
52 self.print_all_tensor = False
53 self.print_tensor_index = []
# Values arrive as strings from argparse (no type= is given), so convert to int.
54 for tensor_index in args.tensor:
55 self.print_tensor_index.append(int(tensor_index))
57 # Set operator index list to print information
# Same scheme as the tensor list: None/empty means "print all operators".
58 self.print_all_operator = True
59 if (args.operator != None):
60 if (len(args.operator) != 0):
61 self.print_all_operator = False
62 self.print_operator_index = []
63 for operator_index in args.operator:
64 self.print_operator_index.append(int(operator_index))
# NOTE(review): the lines that set self.save (read in the subgraph loop below)
# and any condition guarding the next two assignments are not visible here;
# presumably they depend on the -c/--config flag - confirm.
70 self.save_config = True
# File prefix handed to ModelSaver.SaveConfigInfo in SaveModel.
73 self.save_prefix = args.prefix
# Print one subgraph through SubgraphPrinter, restricted to the requested
# tensor/operator indices when specific ones were given on the command line.
75 def PrintModel(self, model_name, op_parser):
76 printer = SubgraphPrinter(self.print_level, op_parser, model_name)
78 if self.print_all_tensor == False:
79 printer.SetPrintSpecificTensors(self.print_tensor_index)
81 if self.print_all_operator == False:
82 printer.SetPrintSpecificOperators(self.print_operator_index)
# NOTE(review): no call that makes the configured printer emit output is
# visible in this listing - confirm against the full file.
# Save per-operator configuration info for one subgraph via ModelSaver,
# passing self.save_prefix when config saving is enabled.
86 def SaveModel(self, model_name, op_parser):
87 saver = ModelSaver(model_name, op_parser)
89 if self.save_config == True:
90 saver.SaveConfigInfo(self.save_prefix)
# NOTE(review): the `def` line of this method is missing from the listing;
# the __main__ block calls it as `.main()`. It walks every subgraph in the
# file and then prints the aggregated statistics.
93 # Generate Model: top structure of tflite model file
# The whole file is read into memory and parsed from its flatbuffer root at offset 0.
94 buf = self.tflite_file.read()
96 tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
# Accumulates statistics across all subgraphs (summed with += below).
98 stats = graph_stats.GraphStats()
99 # Model file can have many models
100 for subgraph_index in range(tf_model.SubgraphsLength()):
101 tf_subgraph = tf_model.Subgraphs(subgraph_index)
# Label each subgraph as "#<index> <name>".
# NOTE(review): if Name() returns bytes (as flatbuffers-generated Python
# accessors typically do), the label will render as b'...' - confirm.
102 model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
103 # 0th subgraph is main subgraph
104 if (subgraph_index == 0):
105 model_name += " (MAIN)"
108 op_parser = OperatorParser(tf_model, tf_subgraph)
111 stats += graph_stats.CalcGraphStats(op_parser)
113 if self.save == False:
114 # print all of operators or requested objects
115 self.PrintModel(model_name, op_parser)
# NOTE(review): the comment below implies this call belongs to the "save"
# case, but no `else:` is visible in this listing - confirm it is not
# meant to run unconditionally.
117 # save all of operators in this model
118 self.SaveModel(model_name, op_parser)
# After all subgraphs: print a header and the aggregated stats at the chosen print level.
120 print('==== Model Stats ({} Subgraphs) ===='.format(tf_model.SubgraphsLength()))
122 graph_stats.PrintGraphStats(stats, self.print_level)
# Script entry point: parse command-line options and run TFLiteModelFileParser on the input file.
125 if __name__ == '__main__':
126 # Define argument and read
127 arg_parser = argparse.ArgumentParser()
# Positional input; argparse opens it in binary mode ('rb') for the flatbuffer reader.
128 arg_parser.add_argument(
129 "input_file", type=argparse.FileType('rb'), help="tflite file to read")
# Integer print level (default 1); the class comments give the valid range as 0~1.
130 arg_parser.add_argument(
131 '-v', '--verbose', type=int, default=1, help="set print level (0~1, default: 1)")
# Zero or more tensor indices (nargs='*'); omitted or empty means all tensors.
132 arg_parser.add_argument(
133 '-t', '--tensor', nargs='*', help="tensor ID to print information (default: all)")
134 arg_parser.add_argument(
# NOTE(review): the option names and other keyword arguments for this
# argument are not visible in this listing; the class reads it as
# args.operator and iterates/len()s it like the tensor list - confirm.
138 help="operator ID to print information (default: all)")
139 arg_parser.add_argument(
# NOTE(review): the option names for this argument are not visible in this
# listing; the --prefix help text below refers to it as -c/--config - confirm.
143 help="Save the configuration file per operator")
# Prefix handed through to ModelSaver.SaveConfigInfo when saving config info.
144 arg_parser.add_argument(
145 '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
146 args = arg_parser.parse_args()
149 TFLiteModelFileParser(args).main()