Model parser: define class (#1291)
author오형석/동작제어Lab(SR)/Senior Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 23 May 2018 00:42:57 +0000 (09:42 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 23 May 2018 00:42:57 +0000 (09:42 +0900)
Define model parser class: TFLiteModelFileParser

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/tflitefile_tool/model_parser.py

index 18d4acf..9e3d5df 100755 (executable)
@@ -75,111 +75,108 @@ def SetBuiltinOpcodeStr():
     BuiltinOpcodeStrList[53] = "CAST"
 
 
-def PrintOperatorInfo(tf_model, tf_subgraph):
+class TFLiteModelFileParser(object):
+    def __init__(self, args):
+        # Read flatbuffer file descriptor using argument
+        self.tflite_file = args.input_file
 
-    print("Operators list\n")
+    def PrintOperatorInfo(self, tf_model, tf_subgraph):
+        print("Operators list\n")
 
-    for operator_idx in range(tf_subgraph.OperatorsLength()):
-        tf_operator = tf_subgraph.Operators(operator_idx)
-        operator_inputs = tf_operator.InputsAsNumpy()
-        operator_outputs = tf_operator.OutputsAsNumpy()
+        for operator_idx in range(tf_subgraph.OperatorsLength()):
+            tf_operator = tf_subgraph.Operators(operator_idx)
+            operator_inputs = tf_operator.InputsAsNumpy()
+            operator_outputs = tf_operator.OutputsAsNumpy()
 
-        opcode_list_idx = tf_operator.OpcodeIndex()
-        opcode_id = tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
-        opcode_str = BuiltinOpcodeStrList[opcode_id]
+            opcode_list_idx = tf_operator.OpcodeIndex()
+            opcode_id = tf_model.OperatorCodes(opcode_list_idx).BuiltinCode()
+            opcode_str = BuiltinOpcodeStrList[opcode_id]
 
-        if opcode_id == 32:
-            # Custom operator
-            custom_operator = tf_model.OperatorCodes(tf_operator.OpcodeIndex())
-            custom_op_name = custom_operator.CustomCode().decode('utf-8')
-            opcode_str = opcode_str + "(" + custom_op_name + ")"
+            if opcode_id == 32:
+                # Custom operator
+                custom_operator = tf_model.OperatorCodes(tf_operator.OpcodeIndex())
+                custom_op_name = custom_operator.CustomCode().decode('utf-8')
+                opcode_str = opcode_str + "(" + custom_op_name + ")"
 
-        print("Operator " + str(operator_idx) + ": " + opcode_str)
-        print("\tInput Tensors" + str(operator_inputs))
-        print("\tOutput Tensors" + str(operator_outputs))
+            print("Operator " + str(operator_idx) + ": " + opcode_str)
+            print("\tInput Tensors" + str(operator_inputs))
+            print("\tOutput Tensors" + str(operator_outputs))
 
-        print('')
-
-    print('')
-
-
-def GetShapeStringFromTensor(tf_tensor):
-    if tf_tensor.ShapeLength() == 0:
-        return "Scalar"
-
-    return_string = "["
-
-    for shape_idx in range(tf_tensor.ShapeLength()):
-        if (shape_idx != 0):
-            return_string += ", "
-        return_string += str(tf_tensor.Shape(shape_idx))
-
-    return_string += "]"
+            print('')
 
-    return return_string
+        print('')
 
+    def GetShapeStringFromTensor(self, tf_tensor):
+        if tf_tensor.ShapeLength() == 0:
+            return "Scalar"
 
-def PrintTensorInfo(tf_model, tf_subgraph):
+        return_string = "["
 
-    print("Tensor-Buffer mapping & shape\n")
+        for shape_idx in range(tf_tensor.ShapeLength()):
+            if (shape_idx != 0):
+                return_string += ", "
+            return_string += str(tf_tensor.Shape(shape_idx))
 
-    for tensor_idx in range(tf_subgraph.TensorsLength()):
-        tf_tensor = tf_subgraph.Tensors(tensor_idx)
-        buffer_idx = tf_tensor.Buffer()
-        tf_buffer = tf_model.Buffers(buffer_idx)
-        isEmpty = "Filled"
-        if (tf_buffer.DataLength() == 0):
-            isEmpty = " Empty"
-        shape_str = GetShapeStringFromTensor(tf_tensor)
+        return_string += "]"
 
-        shape_name = ""
-        if tf_tensor.Name() != 0:
-            shape_name = tf_tensor.Name()
+        return return_string
 
-        print_str = "Tensor {0:4} : buffer {1:4} | {2} | Shape {3} | {4}".format(
-            tensor_idx, buffer_idx, isEmpty, shape_str, shape_name)
-        print(print_str)
+    def PrintTensorInfo(self, tf_model, tf_subgraph):
+        print("Tensor-Buffer mapping & shape\n")
 
-    print('')
+        for tensor_idx in range(tf_subgraph.TensorsLength()):
+            tf_tensor = tf_subgraph.Tensors(tensor_idx)
+            buffer_idx = tf_tensor.Buffer()
+            tf_buffer = tf_model.Buffers(buffer_idx)
+            isEmpty = "Filled"
+            if (tf_buffer.DataLength() == 0):
+                isEmpty = " Empty"
+            shape_str = self.GetShapeStringFromTensor(tf_tensor)
 
+            shape_name = ""
+            if tf_tensor.Name() != 0:
+                shape_name = tf_tensor.Name()
 
-def main(args):
-    # Read flatbuffer file descriptor using argument
-    tflite_file = args.input_file
+            print_str = "Tensor {0:4} : buffer {1:4} | {2} | Shape {3} | {4}".format(
+                tensor_idx, buffer_idx, isEmpty, shape_str, shape_name)
+            print(print_str)
 
-    # Built-in operator string table
-    SetBuiltinOpcodeStr()
+        print('')
 
-    # Generate Model: top structure of tflite model file
-    buf = tflite_file.read()
-    buf = bytearray(buf)
-    tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
+    def main(self):
+        # Generate Model: top structure of tflite model file
+        buf = self.tflite_file.read()
+        buf = bytearray(buf)
+        tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)
 
-    # Model file can have many models
-    # 1st subgraph is main model
-    model_name = "Main model"
-    for subgraph_index in range(tf_model.SubgraphsLength()):
-        tf_subgraph = tf_model.Subgraphs(subgraph_index)
-        if (subgraph_index != 0):
-            model_name = "Model #" + str(subgraph_index)
+        # Model file can have many models
+        # 1st subgraph is main model
+        model_name = "Main model"
+        for subgraph_index in range(tf_model.SubgraphsLength()):
+            tf_subgraph = tf_model.Subgraphs(subgraph_index)
+            if (subgraph_index != 0):
+                model_name = "Model #" + str(subgraph_index)
 
-        print("[" + model_name + "]\n")
+            print("[" + model_name + "]\n")
 
-        # Model inputs & outputs
-        model_inputs = tf_subgraph.InputsAsNumpy()
-        model_outputs = tf_subgraph.OutputsAsNumpy()
+            # Model inputs & outputs
+            model_inputs = tf_subgraph.InputsAsNumpy()
+            model_outputs = tf_subgraph.OutputsAsNumpy()
 
-        print(model_name + " input tensors: " + str(model_inputs))
-        print(model_name + " output tensors: " + str(model_outputs))
+            print(model_name + " input tensors: " + str(model_inputs))
+            print(model_name + " output tensors: " + str(model_outputs))
 
-        # Operators and length of operators
-        PrintOperatorInfo(tf_model, tf_subgraph)
+            # Operators and length of operators
+            self.PrintOperatorInfo(tf_model, tf_subgraph)
 
-        # tensor and length of tensor
-        PrintTensorInfo(tf_model, tf_subgraph)
+            # tensor and length of tensor
+            self.PrintTensorInfo(tf_model, tf_subgraph)
 
 
 if __name__ == '__main__':
+    # Built-in operator string table
+    SetBuiltinOpcodeStr()
+
     # Define argument and read
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(
@@ -187,4 +184,4 @@ if __name__ == '__main__':
     args = arg_parser.parse_args()
 
     # Call main function
-    main(args)
+    TFLiteModelFileParser(args).main()