[model_parser] Print options of Conv2D, AvgPool2D, MaxPool2D (#3573)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 14 Nov 2018 04:34:31 +0000 (13:34 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 14 Nov 2018 04:34:30 +0000 (13:34 +0900)
Through this commit, model parser prints options of Conv2D, AvgPool2D, MaxPool2D such as stride, padding, etc.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
tools/tflitefile_tool/operator_printer.py
tools/tflitefile_tool/operator_wrapping.py
tools/tflitefile_tool/option_printer.py [new file with mode: 0644]

index 779f6de..1ed4811 100644 (file)
@@ -16,6 +16,7 @@
 
 from operator_wrapping import Operator
 from tensor_printer import TensorPrinter
+from option_printer import OptionPrinter
 from perf_predictor import PerfPredictor
 
 
@@ -63,3 +64,8 @@ class OperatorPrinter(object):
         print("\tOutput Tensors" + GetStrTensorIndex(self.operator.outputs))
         for tensor in self.operator.outputs:
             TensorPrinter(self.verbose, tensor).PrintInfo("\t\t")
+
+        # operator option
+        # Some operations does not have option. In such case no option is printed
+        OptionPrinter(self.verbose, self.operator.opcode_str,
+                      self.operator.options).PrintInfo("\t")
index 55d9172..98d9e82 100755 (executable)
@@ -87,6 +87,8 @@ class Operator(object):
         self.operation = Operation(self.tf_operator, self.opcode_str, self.inputs,
                                    self.outputs)
         self.fused_activation = "NONE"
+        self.options = OptionLoader.GetBuiltinOptions(
+            self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions())
         self.SetupFusedActivation()
 
     def SetupFusedActivation(self):
diff --git a/tools/tflitefile_tool/option_printer.py b/tools/tflitefile_tool/option_printer.py
new file mode 100644 (file)
index 0000000..0e55c7f
--- /dev/null
@@ -0,0 +1,54 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class OptionPrinter(object):
+    def __init__(self, verbose, op_name, options):
+        self.verbose = verbose
+        self.op_name = op_name
+        self.options = options
+
+    def GetPadding(self):
+        if self.options.Padding() == 0:
+            return "SAME"
+        elif self.options.Padding() == 1:
+            return "VALID"
+        else:
+            return "** wrong padding value **"
+
+    def PrintInfo(self, tab=""):
+        if (self.verbose < 1):
+            pass
+
+        if (self.op_name == "AVERAGE_POOL_2D" or self.op_name == "MAX_POOL_2D"):
+            print("{}Options".format(tab))
+
+            print("{}\t{}, {}, {}".format(
+                tab, "Filter W:H = {}:{}".format(self.options.FilterWidth(),
+                                                 self.options.FilterHeight()),
+                "Stride W:H = {}:{}".format(self.options.StrideW(),
+                                            self.options.StrideH()),
+                "Padding = {}".format(self.GetPadding())))
+
+        elif (self.op_name == "CONV_2D"):
+            print("{}Options".format(tab))
+
+            print("{}\t{}, {}, {}".format(
+                tab, "Stride W:H = {}:{}".format(self.options.StrideW(),
+                                                 self.options.StrideH()),
+                "Dilation W:H = {}:{}".format(self.options.DilationWFactor(),
+                                              self.options.DilationHFactor()),
+                "Padding = {}".format(self.GetPadding())))