From 40a1aeaec2b1c2bf14530b95923f3f378030d508 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=ED=98=84=EC=8B=9D/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Principal=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Wed, 14 Nov 2018 13:34:31 +0900 Subject: [PATCH] [model_parser] Print options of Conv2D, AvgPool2D, MaxPool2D (#3573) Through this commit, model parser prints options of Conv2D, AvgPool2D, MaxPool2D such as stride, padding, etc. Signed-off-by: Hyun Sik Yoon --- tools/tflitefile_tool/operator_printer.py | 6 ++++ tools/tflitefile_tool/operator_wrapping.py | 2 ++ tools/tflitefile_tool/option_printer.py | 54 ++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 tools/tflitefile_tool/option_printer.py diff --git a/tools/tflitefile_tool/operator_printer.py b/tools/tflitefile_tool/operator_printer.py index 779f6de..1ed4811 100644 --- a/tools/tflitefile_tool/operator_printer.py +++ b/tools/tflitefile_tool/operator_printer.py @@ -16,6 +16,7 @@ from operator_wrapping import Operator from tensor_printer import TensorPrinter +from option_printer import OptionPrinter from perf_predictor import PerfPredictor @@ -63,3 +64,8 @@ class OperatorPrinter(object): print("\tOutput Tensors" + GetStrTensorIndex(self.operator.outputs)) for tensor in self.operator.outputs: TensorPrinter(self.verbose, tensor).PrintInfo("\t\t") + + # operator option + # Some operations does not have option. In such case no option is printed + OptionPrinter(self.verbose, self.operator.opcode_str, + self.operator.options).PrintInfo("\t") diff --git a/tools/tflitefile_tool/operator_wrapping.py b/tools/tflitefile_tool/operator_wrapping.py index 55d9172..98d9e82 100755 --- a/tools/tflitefile_tool/operator_wrapping.py +++ b/tools/tflitefile_tool/operator_wrapping.py @@ -87,6 +87,8 @@ class Operator(object): self.operation = Operation(self.tf_operator, self.opcode_str, self.inputs, self.outputs) self.fused_activation = "NONE" + self.options = OptionLoader.GetBuiltinOptions( + self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions()) self.SetupFusedActivation() def SetupFusedActivation(self): diff --git a/tools/tflitefile_tool/option_printer.py b/tools/tflitefile_tool/option_printer.py new file mode 100644 index 0000000..0e55c7f --- /dev/null +++ b/tools/tflitefile_tool/option_printer.py @@ -0,0 +1,54 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class OptionPrinter(object): + def __init__(self, verbose, op_name, options): + self.verbose = verbose + self.op_name = op_name + self.options = options + + def GetPadding(self): + if self.options.Padding() == 0: + return "SAME" + elif self.options.Padding() == 1: + return "VALID" + else: + return "** wrong padding value **" + + def PrintInfo(self, tab=""): + if (self.verbose < 1): + pass + + if (self.op_name == "AVERAGE_POOL_2D" or self.op_name == "MAX_POOL_2D"): + print("{}Options".format(tab)) + + print("{}\t{}, {}, {}".format( + tab, "Filter W:H = {}:{}".format(self.options.FilterWidth(), + self.options.FilterHeight()), + "Stride W:H = {}:{}".format(self.options.StrideW(), + self.options.StrideH()), + "Padding = {}".format(self.GetPadding()))) + + elif (self.op_name == "CONV_2D"): + print("{}Options".format(tab)) + + print("{}\t{}, {}, {}".format( + tab, "Stride W:H = {}:{}".format(self.options.StrideW(), + self.options.StrideH()), + "Dilation W:H = {}:{}".format(self.options.DilationWFactor(), + self.options.DilationHFactor()), + "Padding = {}".format(self.GetPadding()))) -- 2.7.4