TFLite model file parser
authorHyeongseok Oh <hseok82.oh@samsung.com>
Thu, 22 Mar 2018 10:19:49 +0000 (19:19 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 26 Mar 2018 04:49:21 +0000 (13:49 +0900)
Print operator and tensor information in TF Lite model file

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/tflitefile_tool/model_parser.py [new file with mode: 0755]

diff --git a/tools/tflitefile_tool/model_parser.py b/tools/tflitefile_tool/model_parser.py
new file mode 100755 (executable)
index 0000000..fc55e1a
--- /dev/null
@@ -0,0 +1,174 @@
+#!/usr/bin/python
+import os
+import sys
+import numpy
+
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tflite'))
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../externals/flatbuffers/python'))
+
+import flatbuffers
+import tflite.Model
+import tflite.SubGraph
+import tflite.Tensor
+import tflite.Operator
+import tflite.OperatorCode
+import tflite.BuiltinOperator
+
+def OpenTFLiteFile():
+    if len(sys.argv) != 2:
+        print("Error: incorrect argument")
+        print(">", sys.argv[0], " [filename]")
+        print("\t [filename]: tflite file")
+        exit(1)
+
+    input_file_name = sys.argv[1]
+    try:
+        ifp = open(input_file_name, 'rb')
+    except:
+        print("Error: incorrect file name")
+        print("Cannot open tflite model file")
+        exit(1)
+    return ifp
+
+BuiltinOpcodeStrList = {}
+
+def SetBuiltinOpcodeStr():
+    BuiltinOpcodeStrList[0] = "ADD"
+    BuiltinOpcodeStrList[1] = "AVERAGE_POOL_2D"
+    BuiltinOpcodeStrList[2] = "CONCATENATION"
+    BuiltinOpcodeStrList[3] = "CONV_2D"
+    BuiltinOpcodeStrList[4] = "DEPTHWISE_CONV_2D"
+    BuiltinOpcodeStrList[7] = "EMBEDDING_LOOKUP"
+    BuiltinOpcodeStrList[9] = "FULLY_CONNECTED"
+    BuiltinOpcodeStrList[10] = "HASHTABLE_LOOKUP"
+    BuiltinOpcodeStrList[11] = "L2_NORMALIZATION"
+    BuiltinOpcodeStrList[12] = "L2_POOL_2D"
+    BuiltinOpcodeStrList[13] = "LOCAL_RESPONSE_NORMALIZATION"
+    BuiltinOpcodeStrList[14] = "LOGISTIC"
+    BuiltinOpcodeStrList[15] = "LSH_PROJECTION"
+    BuiltinOpcodeStrList[16] = "LSTM"
+    BuiltinOpcodeStrList[17] = "MAX_POOL_2D"
+    BuiltinOpcodeStrList[18] = "MUL"
+    BuiltinOpcodeStrList[19] = "RELU"
+    BuiltinOpcodeStrList[20] = "RELU_N1_TO_1"
+    BuiltinOpcodeStrList[21] = "RELU6"
+    BuiltinOpcodeStrList[22] = "RESHAPE"
+    BuiltinOpcodeStrList[23] = "RESIZE_BILINEAR"
+    BuiltinOpcodeStrList[24] = "RNN"
+    BuiltinOpcodeStrList[25] = "SOFTMAX"
+    BuiltinOpcodeStrList[26] = "SPACE_TO_DEPTH"
+    BuiltinOpcodeStrList[27] = "SVDF"
+    BuiltinOpcodeStrList[28] = "TANH"
+    BuiltinOpcodeStrList[29] = "CONCAT_EMBEDDINGS"
+    BuiltinOpcodeStrList[30] = "SKIP_GRAM"
+    BuiltinOpcodeStrList[31] = "CALL"
+    BuiltinOpcodeStrList[32] = "CUSTOM"
+    BuiltinOpcodeStrList[33] = "EMBEDDING_LOOKUP_SPARSE"
+    BuiltinOpcodeStrList[34] = "PAD"
+    BuiltinOpcodeStrList[35] = "UNIDIRECTIONAL_SEQUENCE_RNN"
+    BuiltinOpcodeStrList[36] = "GATHER"
+    BuiltinOpcodeStrList[37] = "BATCH_TO_SPACE_ND"
+    BuiltinOpcodeStrList[38] = "SPACE_TO_BATCH_ND"
+    BuiltinOpcodeStrList[39] = "TRANSPOSE"
+    BuiltinOpcodeStrList[40] = "MEAN"
+    BuiltinOpcodeStrList[41] = "SUB"
+    BuiltinOpcodeStrList[42] = "DIV"
+    BuiltinOpcodeStrList[43] = "SQUEEZE"
+    BuiltinOpcodeStrList[44] = "UNIDIRECTIONAL_SEQUENCE_LSTM"
+    BuiltinOpcodeStrList[45] = "STRIDED_SLICE"
+    BuiltinOpcodeStrList[46] = "BIDIRECTIONAL_SEQUENCE_RNN"
+    BuiltinOpcodeStrList[47] = "EXP"
+    BuiltinOpcodeStrList[48] = "TOPK_V2"
+    BuiltinOpcodeStrList[49] = "SPLIT"
+
+def PrintOperatorInfo(tf_model):
+    # Think only one subgraph
+    tf_subgraph = tf_model.Subgraphs(0)
+
+    # Make built-in operator table
+    BuiltInOpcodeList = {}
+    opcode_list_num = tf_model.OperatorCodesLength()
+    opcode_list_idx = 0
+    while opcode_list_idx < opcode_list_num:
+        BuiltInOpcodeList[opcode_list_idx] = tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
+        opcode_list_idx += 1
+
+    tf_operators_num = tf_subgraph.OperatorsLength()
+    operator_idx = 0
+
+    print("Operators list\n")
+
+    while operator_idx < tf_operators_num:
+        tf_operator = tf_subgraph.Operators(operator_idx)
+        operator_inputs = tf_operator.InputsAsNumpy()
+        operator_outputs = tf_operator.OutputsAsNumpy()
+        opcode_idx = BuiltInOpcodeList[tf_operator.OpcodeIndex()]
+        opcode_str = BuiltinOpcodeStrList[opcode_idx]
+
+        if opcode_idx == 32:
+            # Custom operator
+            custom_operator = tf_model.OperatorCodes(tf_operator.OpcodeIndex())
+            opcode_str = opcode_str + "(" + custom_operator.CustomCode() + ")"
+
+        print("Operator " + str(operator_idx) + ": " + opcode_str)
+        print("\tInput Tensors" + str(operator_inputs))
+        print("\tOutput Tensors" + str(operator_outputs))
+
+        print('')
+        operator_idx += 1
+
+    print('')
+
+def PrintTensorInfo(tf_model):
+    # Think only one subgraph
+    tf_subgraph = tf_model.Subgraphs(0)
+
+    tf_tensor_num = tf_subgraph.TensorsLength()
+
+    tensor_idx = 0
+
+    print("Tensor-Buffer mapping\n")
+
+    while tensor_idx < tf_tensor_num:
+        tf_tensor = tf_subgraph.Tensors(tensor_idx)
+        buffer_idx = tf_tensor.Buffer()
+        tf_buffer = tf_model.Buffers(buffer_idx)
+        isEmpty = ""
+        if (tf_buffer.DataLength() == 0):
+            isEmpty = " Empty"
+
+        print("Tensor " + str(tensor_idx) + ": buffer " + str(buffer_idx) + isEmpty)
+        tensor_idx += 1
+
+    print('')
+
+def main():
+    # Read flatbuffer file descriptor using argument
+    tflite_file = OpenTFLiteFile()
+
+    # Built-in operator string table
+    SetBuiltinOpcodeStr()
+
+    # Generate Model: top structure of tflite model file
+    buf = tflite_file.read()
+    buf = bytearray(buf)
+    tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+
+    # Think only one subgraph
+    tf_subgraph = tf_model.Subgraphs(0)
+
+    # Model inputs & outputs
+    model_inputs = tf_subgraph.InputsAsNumpy()
+    model_outputs = tf_subgraph.OutputsAsNumpy()
+
+    print("Model input tensors: " + str(model_inputs))
+    print("Model output tensors: " + str(model_outputs))
+
+    # Operators and length of operators
+    PrintOperatorInfo(tf_model)
+
+    # tensor and length of tensor
+    PrintTensorInfo(tf_model)
+
+if __name__ == '__main__':
+    main()