[tool] make pb_info.py prints name filtered by prefix (#1981)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 17 Jul 2018 10:59:51 +0000 (19:59 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Tue, 17 Jul 2018 10:59:51 +0000 (19:59 +0900)
pb_info.py prints name filtered by prefix.
for example running `pb_info.py --name_prefix='Model/rnn' <more options>` prints operations with their names starting with 'Model/rnn'

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
tools/pbfile_tool/pb_info.py
tools/pbfile_tool/readme.md

index 8101b05..110b15e 100755 (executable)
@@ -78,39 +78,51 @@ def print_operation(op, op_count):
     print("")  # new line
 
 
-def print_graph_info(pb_path, optype_substring):
+def print_graph_info(pb_path, optype_substring, name_prefix):
     with tf.Session() as sess:
         importGraphIntoSession(sess, pb_path)
 
+        op_seq = 1
         op_count = 1
         graph = sess.graph
         ops = graph.get_operations()
         for op in ops:
-            if optype_substring == "*":
-                print_operation(op, op_count)
-            elif op.type.lower().find(optype_substring.lower()) != -1:
-                print_operation(op, op_count)
+            if optype_substring == "*" and (name_prefix == None
+                                            or op.name.startswith(name_prefix)):
+                print_operation(op, op_seq)
+                op_count += 1
+            elif op.type.lower().find(optype_substring.lower()) != -1 and (
+                    name_prefix == None or op.name.startswith(name_prefix)):
+                print_operation(op, op_seq)
+                op_count += 1
             else:
                 print("skipping {}, name = {}".format(op.type, op.name))
-            op_count += 1
+            op_seq += 1
+
+        print("")
+        print("Total number of operations : " + str(op_count))
+        print("")
 
 
-def print_summary(pb_path, optype_substring):
+def print_summary(pb_path, optype_substring, name_prefix):
     op_map = {}
+    op_count = 0
     with tf.Session() as sess:
         importGraphIntoSession(sess, pb_path)
 
-        op_count = 1
         graph = sess.graph
         ops = graph.get_operations()
         for op in ops:
             process = False
-            if optype_substring == "*":
+            if optype_substring == "*" and (name_prefix == None
+                                            or op.name.startswith(name_prefix)):
                 process = True
-            elif op.type.lower().find(optype_substring.lower()) != -1:
+            elif op.type.lower().find(optype_substring.lower()) != -1 and (
+                    name_prefix == None or op.name.startswith(name_prefix)):
                 process = True
 
             if process:
+                op_count += 1
                 if op_map.get(op.type) == None:
                     op_map[op.type] = 1
                 else:
@@ -118,10 +130,12 @@ def print_summary(pb_path, optype_substring):
 
         # print op list
         print("")
-        print("Total number of operation types : " + str(len(op_map.keys())))
-        print("")
         for op_type, count in op_map.items():
             print("\t" + op_type + " : \t" + str(count))
+        print("")
+        print("Total number of operations : " + str(op_count))
+        print("Total number of operation types : " + str(len(op_map.keys())))
+        print("")
 
 
 if __name__ == "__main__":
@@ -134,10 +148,11 @@ if __name__ == "__main__":
         help="substring of operations. only info of these operasions will be printed.")
     parser.add_argument(
         "--summary", help="print summary of operations", action="store_true")
+    parser.add_argument("--name_prefix", help="filtered by speficied name prefix")
 
     args = parser.parse_args()
 
     if args.summary:
-        print_summary(args.pb_file, args.op_subst)
+        print_summary(args.pb_file, args.op_subst, args.name_prefix)
     else:
-        print_graph_info(args.pb_file, args.op_subst)
+        print_graph_info(args.pb_file, args.op_subst, args.name_prefix)
index 15594e6..7b3fe1a 100644 (file)
@@ -8,3 +8,6 @@
         - pass "*" as the second param to print all operations
     - `./tools/pbfile_tool/pb_info.py pbfile_path "*" --summary`
         - prints the list of operations and their counts
+    - `./tools/pbfile_tool/pb_info.py pbfile_path "*" --summary --name_prefix=Model/rnn`
+        - prints the summary of operations of which names start `Model/rnn`
+