[model_parser] Supports printing specific tensors and operators (#3580)
author김용섭/동작제어Lab(SR)/Engineer/삼성전자 <yons.kim@samsung.com>
Thu, 15 Nov 2018 02:09:24 +0000 (11:09 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 15 Nov 2018 02:09:24 +0000 (11:09 +0900)
Now model_parser.py supports argument options for features which prints
specific tensors and operators.

./model_parser.py ./tflite -t 1 2 3   # tensor id
./model_parser.py ./tflite -o 1 2 3   # operator id

Signed-off-by: Yongseop Kim <yons.kim@samsung.com>
tools/tflitefile_tool/model_parser.py
tools/tflitefile_tool/model_printer.py
tools/tflitefile_tool/operator_parser.py
tools/tflitefile_tool/operator_printer.py

index 7c5971a..0edabbb 100755 (executable)
@@ -45,35 +45,33 @@ class TFLiteModelFileParser(object):
             self.print_level = 0
 
         # Set tensor index list to print information
-        # TODO:
-        #   Print tensors in list only
-        #   Print all tensors if argument used and not specified index number
+        self.print_all_tensor = True
         if (args.tensor != None):
-            if (len(args.tensor) == 0):
-                self.print_all_tensor = True
-            else:
+            if (len(args.tensor) != 0):
                 self.print_all_tensor = False
                 self.print_tensor_index = []
-
                 for tensor_index in args.tensor:
                     self.print_tensor_index.append(int(tensor_index))
 
         # Set operator index list to print information
-        # TODO:
-        #   Print operators in list only
-        #   Print all operators if argument used and not specified index number
+        self.print_all_operator = True
         if (args.operator != None):
-            if (len(args.operator) == 0):
-                self.print_all_oeprator = True
-            else:
-                self.print_all_oeprator = False
+            if (len(args.operator) != 0):
+                self.print_all_operator = False
                 self.print_operator_index = []
-
                 for operator_index in args.operator:
                     self.print_operator_index.append(int(operator_index))
 
     def PrintModel(self, model_name, op_parser):
-        ModelPrinter(self.print_level, op_parser, model_name).PrintAll()
+        printer = ModelPrinter(self.print_level, op_parser, model_name)
+
+        if self.print_all_tensor == False:
+            printer.SetPrintSpecificTensors(self.print_tensor_index)
+
+        if self.print_all_operator == False:
+            printer.SetPrintSpecificOperators(self.print_operator_index)
+
+        printer.PrintInfo()
 
     def main(self):
         # Generate Model: top structure of tflite model file
index 019f68e..cac8f68 100644 (file)
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 from operator_printer import OperatorPrinter
+from tensor_printer import TensorPrinter
 
 
 class ModelPrinter(object):
@@ -22,11 +23,34 @@ class ModelPrinter(object):
         self.verbose = verbose
         self.op_parser = op_parser
         self.model_name = model_name
+        self.print_all_tensor = True
+        self.print_tensor_index_list = None
+        self.print_all_operator = True
+        self.print_operator_index_list = None
+
+    def SetPrintSpecificTensors(self, tensor_indices):
+        if len(tensor_indices) != 0:
+            self.print_all_tensor = False
+            self.print_tensor_index_list = tensor_indices
+
+    def SetPrintSpecificOperators(self, operator_indices):
+        if len(operator_indices) != 0:
+            self.print_all_operator = False
+            self.print_operator_index_list = operator_indices
+
+    def PrintInfo(self):
+        if self.print_all_tensor == True and self.print_all_operator == True:
+            self.PrintModelInfo()
+            self.PrintAllOperatorsInList()
+            self.PrintAllTypesInfo()
+
+        if self.print_all_tensor == False:
+            print('')
+            self.PrintSpecificTensors()
 
-    def PrintAll(self):
-        self.PrintModelInfo()
-        self.PrintAllOperatorsInList()
-        self.PrintAllTypesInfo()
+        if self.print_all_operator == False:
+            print('')
+            self.PrintSpecificOperators()
 
     def PrintModelInfo(self):
         print("[" + self.model_name + "]\n")
@@ -89,3 +113,20 @@ class ModelPrinter(object):
 
         print(summary_str)
         print('')
+
+    def PrintSpecificTensors(self):
+        for tensor in self.op_parser.GetAllTensors():
+            if tensor.tensor_idx in self.print_tensor_index_list:
+                printer = TensorPrinter(self.verbose, tensor)
+                printer.PrintInfo()
+                print('')
+        print('')
+
+    def PrintSpecificOperators(self):
+        for operator in self.op_parser.operators_in_list:
+            if operator.operator_idx in self.print_operator_index_list:
+                printer = OperatorPrinter(self.verbose, operator)
+                printer.PrintInfo(self.op_parser.perf_predictor)
+                print('')
+
+        print('')
index 5b080b0..71b1a6d 100755 (executable)
@@ -76,6 +76,18 @@ class OperatorParser(object):
             return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer))
         return return_list
 
+    def GetAllTensors(self):
+        return_list = list()
+        for tensor_idx in range(self.tf_subgraph.TensorsLength()):
+            if (tensor_idx < 0):
+                return_list.append(Tensor(tensor_idx, 0, 0))
+                continue
+            tf_tensor = self.tf_subgraph.Tensors(tensor_idx)
+            buffer_idx = tf_tensor.Buffer()
+            tf_buffer = self.tf_model.Buffers(buffer_idx)
+            return_list.append(Tensor(tensor_idx, tf_tensor, tf_buffer))
+        return return_list
+
     def AppendOperator(self, operator):
         self.operators_in_list.append(operator)
 
index 1ed4811..9b6f97d 100644 (file)
@@ -37,7 +37,7 @@ class OperatorPrinter(object):
 
     def PrintInfo(self, perf_predictor=None):
         if (self.verbose < 1):
-            pass
+            return
 
         op_str = "Operator {0}: {1}".format(self.operator.operator_idx,
                                             self.operator.opcode_str)
@@ -55,9 +55,10 @@ class OperatorPrinter(object):
             op_str = op_str + "(instrs: {0}, cycls: {1})".format(instrs, cycles)
 
         print(op_str)
-
         print("\tFused Activation: " + self.operator.fused_activation)
+        self.PrintTensors()
 
+    def PrintTensors(self):
         print("\tInput Tensors" + GetStrTensorIndex(self.operator.inputs))
         for tensor in self.operator.inputs:
             TensorPrinter(self.verbose, tensor).PrintInfo("\t\t")