op_count += 1
+def print_summary(pb_path, optype_substring):
+ op_map = {}
+ with tf.Session() as sess:
+ importGraphIntoSession(sess, pb_path)
+
+ op_count = 1
+ graph = sess.graph
+ ops = graph.get_operations()
+ for op in ops:
+ process = False
+ if optype_substring == "*":
+ process = True
+ elif op.type.lower().find(optype_substring.lower()) != -1:
+ process = True
+
+ if process:
+ if op_map.get(op.type) == None:
+ op_map[op.type] = 1
+ else:
+ op_map[op.type] += 1
+
+ # print op list
+ print("")
+ print("Total number of operation types : " + str(len(op_map.keys())))
+ print("")
+ for op_type, count in op_map.items():
+ print("\t" + op_type + " : \t" + str(count))
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Prints information inside pb file')
parser.add_argument(
"op_subst",
help="substring of operations. only info of these operasions will be printed.")
+ parser.add_argument(
+ "--summary", help="print summary of operations", action="store_true")
args = parser.parse_args()
- print_graph_info(args.pb_file, args.op_subst)
+ if args.summary:
+ print_summary(args.pb_file, args.op_subst)
+ else:
+ print_graph_info(args.pb_file, args.op_subst)
- first arg: pb file
- second arg: substring of operation. Only operations that has "conv" substring as its type will be printed. (case-insensitive)
- `./tools/pbfile_tool/pb_info.py pbfile_path "*"`
- - pass "*" as the second param to print all operations
+ - pass "*" as the second param to print all operations
+ - `./tools/pbfile_tool/pb_info.py pbfile_path "*" --summary`
+ - prints the list of operations and their counts