Update model parser: multiple subgraph (#210)
author오형석/동작제어Lab(SR)/Senior Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 27 Mar 2018 06:22:11 +0000 (15:22 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Tue, 27 Mar 2018 06:22:11 +0000 (15:22 +0900)
Update model parser

- Can handle multiple subgraph
- Change while loop
- Remove opcode list generation

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/tflitefile_tool/model_parser.py

index fc55e1a..6b9c022 100755 (executable)
@@ -81,31 +81,20 @@ def SetBuiltinOpcodeStr():
     BuiltinOpcodeStrList[48] = "TOPK_V2"
     BuiltinOpcodeStrList[49] = "SPLIT"
 
-def PrintOperatorInfo(tf_model):
-    # Think only one subgraph
-    tf_subgraph = tf_model.Subgraphs(0)
-
-    # Make built-in operator table
-    BuiltInOpcodeList = {}
-    opcode_list_num = tf_model.OperatorCodesLength()
-    opcode_list_idx = 0
-    while opcode_list_idx < opcode_list_num:
-        BuiltInOpcodeList[opcode_list_idx] = tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
-        opcode_list_idx += 1
-
-    tf_operators_num = tf_subgraph.OperatorsLength()
-    operator_idx = 0
+def PrintOperatorInfo(tf_model, tf_subgraph):
 
     print("Operators list\n")
 
-    while operator_idx < tf_operators_num:
+    for operator_idx in range(tf_subgraph.OperatorsLength()):
         tf_operator = tf_subgraph.Operators(operator_idx)
         operator_inputs = tf_operator.InputsAsNumpy()
         operator_outputs = tf_operator.OutputsAsNumpy()
-        opcode_idx = BuiltInOpcodeList[tf_operator.OpcodeIndex()]
-        opcode_str = BuiltinOpcodeStrList[opcode_idx]
 
-        if opcode_idx == 32:
+        opcode_list_idx = tf_operator.OpcodeIndex()
+        opcode_id = tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
+        opcode_str = BuiltinOpcodeStrList[opcode_id]
+
+        if opcode_id == 32:
             # Custom operator
             custom_operator = tf_model.OperatorCodes(tf_operator.OpcodeIndex())
             opcode_str = opcode_str + "(" + custom_operator.CustomCode() + ")"
@@ -115,21 +104,14 @@ def PrintOperatorInfo(tf_model):
         print("\tOutput Tensors" + str(operator_outputs))
 
         print('')
-        operator_idx += 1
 
     print('')
 
-def PrintTensorInfo(tf_model):
-    # Think only one subgraph
-    tf_subgraph = tf_model.Subgraphs(0)
-
-    tf_tensor_num = tf_subgraph.TensorsLength()
-
-    tensor_idx = 0
+def PrintTensorInfo(tf_model, tf_subgraph):
 
     print("Tensor-Buffer mapping\n")
 
-    while tensor_idx < tf_tensor_num:
+    for tensor_idx in range(tf_subgraph.TensorsLength()):
         tf_tensor = tf_subgraph.Tensors(tensor_idx)
         buffer_idx = tf_tensor.Buffer()
         tf_buffer = tf_model.Buffers(buffer_idx)
@@ -138,7 +120,6 @@ def PrintTensorInfo(tf_model):
             isEmpty = " Empty"
 
         print("Tensor " + str(tensor_idx) + ": buffer " + str(buffer_idx) + isEmpty)
-        tensor_idx += 1
 
     print('')
 
@@ -154,21 +135,28 @@ def main():
     buf = bytearray(buf)
     tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
 
-    # Think only one subgraph
-    tf_subgraph = tf_model.Subgraphs(0)
+    # Model file can have many models
+    # 1st subgraph is main model
+    model_name = "Main model"
+    for subgraph_index in range(tf_model.SubgraphsLength()):
+        tf_subgraph = tf_model.Subgraphs(subgraph_index)
+        if (subgraph_index != 0):
+            model_name = "Model #" + str(subgraph_index)
+
+        print("[" + model_name + "]\n")
 
-    # Model inputs & outputs
-    model_inputs = tf_subgraph.InputsAsNumpy()
-    model_outputs = tf_subgraph.OutputsAsNumpy()
+        # Model inputs & outputs
+        model_inputs = tf_subgraph.InputsAsNumpy()
+        model_outputs = tf_subgraph.OutputsAsNumpy()
 
-    print("Model input tensors: " + str(model_inputs))
-    print("Model output tensors: " + str(model_outputs))
+        print(model_name + " input tensors: " + str(model_inputs))
+        print(model_name + " output tensors: " + str(model_outputs))
 
-    # Operators and length of operators
-    PrintOperatorInfo(tf_model)
+        # Operators and length of operators
+        PrintOperatorInfo(tf_model, tf_subgraph)
 
-    # tensor and length of tensor
-    PrintTensorInfo(tf_model)
+        # tensor and length of tensor
+        PrintTensorInfo(tf_model, tf_subgraph)
 
 if __name__ == '__main__':
     main()