[tflitefile_tool] Add config option for benchmark tool to save model config file...
author윤지영/On-Device Lab(SR)/Staff Engineer/삼성전자 <jy910.yun@samsung.com>
Thu, 9 May 2019 05:19:21 +0000 (14:19 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Thu, 9 May 2019 05:19:21 +0000 (14:19 +0900)
* [tflitefile_tool] Add config option for benchmark tool

This patch adds the `-c` and `--config` options.
This option will print the configuration information of each operator.

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
* Save configuration file

This patch allows to save the configuration info about model

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
tools/tflitefile_tool/config_saver.py [new file with mode: 0644]
tools/tflitefile_tool/model_parser.py
tools/tflitefile_tool/model_printer.py
tools/tflitefile_tool/model_saver.py [new file with mode: 0644]
tools/tflitefile_tool/option_printer.py

diff --git a/tools/tflitefile_tool/config_saver.py b/tools/tflitefile_tool/config_saver.py
new file mode 100644 (file)
index 0000000..2510e87
--- /dev/null
@@ -0,0 +1,131 @@
+#!/usr/bin/python
+
+# Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from operator_wrapping import Operator
+from tensor_printer import TensorPrinter
+from option_printer import OptionPrinter
+from perf_predictor import PerfPredictor
+
+
+class ConfigSaver(object):
+    def __init__(self, file_name, operator):
+        self.file_name = file_name
+        self.operator = operator
+        # Set self.verbose to 2 level to print more information
+        self.verbose = 2
+        self.op_idx = operator.operator_idx
+        self.op_name = operator.opcode_str
+
+        self.f = open(file_name, 'at')
+
+    def __del__(self):
+        self.f.close()
+
+    def SaveInfo(self):
+        self.f.write("[{}]\n".format(self.op_idx))
+        if (self.op_name == 'CONV_2D'):
+            self.SaveConv2DInputs()
+        else:
+            self.SaveInputs()
+
+        self.SaveOutputs()
+
+        self.SaveAttributes()
+
+        self.f.write('\n')
+
+    def SaveConv2DInputs(self):
+        if (len(self.operator.inputs) != 3):
+            raise AssertionError('Conv2D input count should be 3')
+
+        inputs = self.operator.inputs[0]
+        weights = self.operator.inputs[1]
+        bias = self.operator.inputs[2]
+
+        self.f.write("input: {}\n".format(
+            TensorPrinter(self.verbose, inputs).GetShapeString()))
+        self.f.write("input_type: {}\n".format(inputs.type_name))
+        self.f.write("weights: {}\n".format(
+            TensorPrinter(self.verbose, weights).GetShapeString()))
+        self.f.write("weights_type: {}\n".format(weights.type_name))
+        self.f.write("bias: {}\n".format(
+            TensorPrinter(self.verbose, bias).GetShapeString()))
+        self.f.write("bias_type: {}\n".format(bias.type_name))
+
+    def SaveInputs(self):
+        total = len(self.operator.inputs)
+        self.f.write("input_counts: {}\n".format(total))
+        for idx in range(total):
+            tensor = self.operator.inputs[idx]
+            input_shape_str = TensorPrinter(self.verbose, tensor).GetShapeString()
+            self.f.write("input{}: {}\n".format(idx, input_shape_str))
+            self.f.write("input{}_type: {}\n".format(idx, tensor.type_name))
+
+    def SaveOutputs(self):
+        total = len(self.operator.outputs)
+        self.f.write("output_counts: {}\n".format(total))
+        for idx in range(total):
+            tensor = self.operator.outputs[idx]
+            output_shape_str = TensorPrinter(self.verbose, tensor).GetShapeString()
+            self.f.write("output{}: {}\n".format(idx, output_shape_str))
+            self.f.write("output{}_type: {}\n".format(idx, tensor.type_name))
+
+    def SaveFilter(self):
+        self.f.write("filter_w: {}\n".format(self.operator.options.FilterWidth()))
+        self.f.write("filter_h: {}\n".format(self.operator.options.FilterHeight()))
+
+    def SaveStride(self):
+        self.f.write("stride_w: {}\n".format(self.operator.options.StrideW()))
+        self.f.write("stride_h: {}\n".format(self.operator.options.StrideH()))
+
+    def SaveDilation(self):
+        self.f.write("dilation_w: {}\n".format(self.operator.options.DilationWFactor()))
+        self.f.write("dilation_h: {}\n".format(self.operator.options.DilationHFactor()))
+
+    def SavePadding(self):
+        if self.operator.options.Padding() == 0:
+            self.f.write("padding: SAME\n")
+        elif self.operator.options.Padding() == 1:
+            self.f.write("padding: VALID\n")
+
+    def SaveFusedAct(self):
+        if self.operator.fused_activation is not "NONE":
+            self.f.write("fused_act: {}\n".format(self.operator.fused_activation))
+
+    def SaveAttributes(self):
+        # operator option
+        # Some operations does not have option. In such case no option is printed
+        option_str = OptionPrinter(self.verbose, self.op_name,
+                                   self.operator.options).GetOptionString()
+        if self.op_name == 'AVERAGE_POOL_2D' or self.op_name == 'MAX_POOL_2D':
+            self.SaveFilter()
+            self.SaveStride()
+            self.SavePadding()
+        elif self.op_name == 'CONV_2D':
+            self.SaveStride()
+            self.SaveDilation()
+            self.SavePadding()
+        elif self.op_name == 'TRANSPOSE_CONV':
+            self.SaveStride()
+            self.SavePadding()
+        elif self.op_name == 'DEPTHWISE_CONV_2D':
+            self.SaveStride()
+            self.SaveDilation()
+            self.SavePadding()
+            self.f.write("depthmultiplier: {}\n".format(
+                self.opeator.options.DepthMultiplier()))
+
+        self.SaveFusedAct()
index 0edabbb..6f9e1c6 100755 (executable)
@@ -29,6 +29,7 @@ import tflite.SubGraph
 import argparse
 from operator_parser import OperatorParser
 from model_printer import ModelPrinter
+from model_saver import ModelSaver
 from perf_predictor import PerfPredictor
 
 
@@ -62,6 +63,15 @@ class TFLiteModelFileParser(object):
                 for operator_index in args.operator:
                     self.print_operator_index.append(int(operator_index))
 
+        # Set config option
+        self.save = False
+        if args.config:
+            self.save = True
+            self.save_config = True
+
+        if self.save == True:
+            self.save_prefix = args.prefix
+
     def PrintModel(self, model_name, op_parser):
         printer = ModelPrinter(self.print_level, op_parser, model_name)
 
@@ -73,6 +83,12 @@ class TFLiteModelFileParser(object):
 
         printer.PrintInfo()
 
+    def SaveModel(self, model_name, op_parser):
+        saver = ModelSaver(model_name, op_parser)
+
+        if self.save_config == True:
+            saver.SaveConfigInfo(self.save_prefix)
+
     def main(self):
         # Generate Model: top structure of tflite model file
         buf = self.tflite_file.read()
@@ -81,18 +97,22 @@ class TFLiteModelFileParser(object):
 
         # Model file can have many models
         # 1st subgraph is main model
-        model_name = "Main model"
+        model_name = "Main_model"
         for subgraph_index in range(tf_model.SubgraphsLength()):
             tf_subgraph = tf_model.Subgraphs(subgraph_index)
             if (subgraph_index != 0):
-                model_name = "Model #" + str(subgraph_index)
+                model_name = "Model_#" + str(subgraph_index)
 
             # Parse Operators
             op_parser = OperatorParser(tf_model, tf_subgraph, PerfPredictor())
             op_parser.Parse()
 
-            # print all of operators or requested objects
-            self.PrintModel(model_name, op_parser)
+            if self.save == False:
+                # print all of operators or requested objects
+                self.PrintModel(model_name, op_parser)
+            else:
+                # save all of operators in this model
+                self.SaveModel(model_name, op_parser)
 
 
 if __name__ == '__main__':
@@ -109,6 +129,13 @@ if __name__ == '__main__':
         '--operator',
         nargs='*',
         help="operator ID to print information (default: all)")
+    arg_parser.add_argument(
+        '-c',
+        '--config',
+        action='store_true',
+        help="Save the configuration file per operator")
+    arg_parser.add_argument(
+        '-p', '--prefix', help="file prefix to be saved (with -c/--config option)")
     args = arg_parser.parse_args()
 
     # Call main function
index fda0291..0c11d01 100644 (file)
@@ -143,3 +143,4 @@ class ModelPrinter(object):
         from tensor_printer import ConvertBytesToHuman
         print("Expected TOTAL  memory: {0}".format(ConvertBytesToHuman(total_memory)))
         print("Expected FILLED memory: {0}".format(ConvertBytesToHuman(filled_memory)))
+        print('')
diff --git a/tools/tflitefile_tool/model_saver.py b/tools/tflitefile_tool/model_saver.py
new file mode 100644 (file)
index 0000000..15037a1
--- /dev/null
@@ -0,0 +1,36 @@
+#!/usr/bin/python
+
+# Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from config_saver import ConfigSaver
+
+
+class ModelSaver(object):
+    def __init__(self, model_name, op_parser):
+        self.model_name = model_name
+        self.op_parser = op_parser
+
+    def SaveConfigInfo(self, prefix):
+        print("Save model configuration file")
+        for type_str, oper_list in self.op_parser.operators_per_type.items():
+            if prefix:
+                file_name = "{}_{}_{}.config".format(prefix, self.model_name, type_str)
+            else:
+                file_name = "{}_{}.config".format(self.model_name, type_str)
+            print("{} file is generated".format(file_name))
+            with open(file_name, 'wt') as f:
+                f.write("# {}, Total count: {}\n\n".format(type_str, len(oper_list)))
+            for operator in oper_list:
+                ConfigSaver(file_name, operator).SaveInfo()
index 08754f1..15265ad 100644 (file)
@@ -35,35 +35,33 @@ class OptionPrinter(object):
         if (self.options == 0):
             return
 
-        if (self.op_name == "AVERAGE_POOL_2D" or self.op_name == "MAX_POOL_2D"):
+        option_str = self.GetOptionString()
+        if option_str:
             print("{}Options".format(tab))
+            print("{}\t{}".format(tab, option_str))
 
-            print("{}\t{}, {}, {}".format(
-                tab, "Filter W:H = {}:{}".format(self.options.FilterWidth(),
-                                                 self.options.FilterHeight()),
+    def GetOptionString(self):
+        if (self.op_name == "AVERAGE_POOL_2D" or self.op_name == "MAX_POOL_2D"):
+            return "{}, {}, {}".format(
+                "Filter W:H = {}:{}".format(self.options.FilterWidth(),
+                                            self.options.FilterHeight()),
                 "Stride W:H = {}:{}".format(self.options.StrideW(),
                                             self.options.StrideH()),
-                "Padding = {}".format(self.GetPadding())))
-
+                "Padding = {}".format(self.GetPadding()))
         elif (self.op_name == "CONV_2D"):
-            print("{}Options".format(tab))
-
-            print("{}\t{}, {}, {}".format(
-                tab, "Stride W:H = {}:{}".format(self.options.StrideW(),
-                                                 self.options.StrideH()),
+            return "{}, {}, {}".format(
+                "Stride W:H = {}:{}".format(self.options.StrideW(),
+                                            self.options.StrideH()),
                 "Dilation W:H = {}:{}".format(self.options.DilationWFactor(),
                                               self.options.DilationHFactor()),
-                "Padding = {}".format(self.GetPadding())))
-
+                "Padding = {}".format(self.GetPadding()))
         elif (self.op_name == "DEPTHWISE_CONV_2D"):
-            print("{}Options".format(tab))
-
             # yapf: disable
-            print("{}\t{}, {}, {}, {}".format(
-                tab, "Stride W:H = {}:{}".format(self.options.StrideW(),
+            return "{}, {}, {}, {}".format(
+                "Stride W:H = {}:{}".format(self.options.StrideW(),
                                                  self.options.StrideH()),
                 "Dilation W:H = {}:{}".format(self.options.DilationWFactor(),
                                               self.options.DilationHFactor()),
                 "Padding = {}".format(self.GetPadding()),
-                "DepthMultiplier = {}".format(self.options.DepthMultiplier())))
+                "DepthMultiplier = {}".format(self.options.DepthMultiplier()))
             # yapf: enable