From 011503b263b7186541ec00f32708413c7a472046 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Mon, 23 Jul 2018 19:03:27 +0900 Subject: [PATCH] [model_parser] Print activation function type (#2052) * [model_parser] Print activation function type Parse activation function type of the operator and print. This also make it possible to print other options in BuiltinOptions for all BuiltinOperators. Signed-off-by: Hanjoung Lee * Extract GetBuiltinOptions to be global function --- tools/tflitefile_tool/operator_wrapping.py | 59 ++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tools/tflitefile_tool/operator_wrapping.py b/tools/tflitefile_tool/operator_wrapping.py index 6ffcb2f..b7f5eb9 100755 --- a/tools/tflitefile_tool/operator_wrapping.py +++ b/tools/tflitefile_tool/operator_wrapping.py @@ -34,6 +34,50 @@ def GetStrTensorIndex(tensors): return return_string +def GetBuiltinOptions(tf_operator): + import tflite.Conv2DOptions + import tflite.DepthwiseConv2DOptions + import tflite.Pool2DOptions + import tflite.FullyConnectedOptions + import tflite.SoftmaxOptions + import tflite.ConcatenationOptions + import tflite.ReshapeOptions + import tflite.AddOptions + import tflite.SubOptions + import tflite.MulOptions + import tflite.DivOptions + import tflite.ResizeBilinearOptions + import tflite.StridedSliceOptions + import tflite.CastOptions + import tflite.TopKV2Options + import tflite.GatherOptions + + bo = tflite.BuiltinOptions.BuiltinOptions() + bo_gen = { + bo.Conv2DOptions: tflite.Conv2DOptions.Conv2DOptions, + bo.DepthwiseConv2DOptions: tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions, + bo.Pool2DOptions: tflite.Pool2DOptions.Pool2DOptions, + bo.FullyConnectedOptions: tflite.FullyConnectedOptions.FullyConnectedOptions, + bo.SoftmaxOptions: tflite.SoftmaxOptions.SoftmaxOptions, + bo.ConcatenationOptions: tflite.ConcatenationOptions.ConcatenationOptions, + bo.ReshapeOptions: tflite.ReshapeOptions.ReshapeOptions, + bo.AddOptions: tflite.AddOptions.AddOptions, + bo.SubOptions: tflite.SubOptions.SubOptions, + bo.MulOptions: tflite.MulOptions.MulOptions, + bo.DivOptions: tflite.DivOptions.DivOptions, + bo.ResizeBilinearOptions: tflite.ResizeBilinearOptions.ResizeBilinearOptions, + bo.StridedSliceOptions: tflite.StridedSliceOptions.StridedSliceOptions, + bo.CastOptions: tflite.CastOptions.CastOptions, + bo.TopKV2Options: tflite.TopKV2Options.TopKV2Options, + bo.GatherOptions: tflite.GatherOptions.GatherOptions + } + + options_table = tf_operator.BuiltinOptions() + options = bo_gen[tf_operator.BuiltinOptionsType()]() + options.Init(options_table.Bytes, options_table.Pos) + return options + + class Operator(object): def __init__(self, operator_idx, tf_operator, input_tensors, output_tensors, opcode_str): @@ -57,6 +101,8 @@ class Operator(object): print("Operator {0}: {1} (ops: {2}, cycls: {3})".format( self.operator_idx, self.opcode_str, counts, cycles)) + self.PrintOptionInfo() + print("\tInput Tensors" + GetStrTensorIndex(self.inputs)) for tensor in self.inputs: tensor.PrintInfo("\t\t") @@ -64,6 +110,19 @@ class Operator(object): for tensor in self.outputs: tensor.PrintInfo("\t\t") + def PrintOptionInfo(self): + options = GetBuiltinOptions(self.tf_operator) + + # fused activation function + try: + activation_code = options.FusedActivationFunction() + fused_activation = str( + activation_code) # TODO print the name, not the integer code + print("\tFused Activation: " + fused_activation) + except AttributeError: + # This operator does not support FusedActivationFunction + pass + def CountOperations(self): opcode_str = self.opcode_str # FIXME: if there would be a class for ops_counters, we can delete this -- 2.7.4