[model_parser] Print activation function type (#2052)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Mon, 23 Jul 2018 10:03:27 +0000 (19:03 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Mon, 23 Jul 2018 10:03:27 +0000 (19:03 +0900)
* [model_parser] Print activation function type

Parse activation function type of the operator and print. This also
make it possible to print other options in BuiltinOptions for all
BuiltinOperators.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
* Extract GetBuiltinOptions to be global function

tools/tflitefile_tool/operator_wrapping.py

index 6ffcb2f..b7f5eb9 100755 (executable)
@@ -34,6 +34,50 @@ def GetStrTensorIndex(tensors):
     return return_string
 
 
+def GetBuiltinOptions(tf_operator):
+    import tflite.Conv2DOptions
+    import tflite.DepthwiseConv2DOptions
+    import tflite.Pool2DOptions
+    import tflite.FullyConnectedOptions
+    import tflite.SoftmaxOptions
+    import tflite.ConcatenationOptions
+    import tflite.ReshapeOptions
+    import tflite.AddOptions
+    import tflite.SubOptions
+    import tflite.MulOptions
+    import tflite.DivOptions
+    import tflite.ResizeBilinearOptions
+    import tflite.StridedSliceOptions
+    import tflite.CastOptions
+    import tflite.TopKV2Options
+    import tflite.GatherOptions
+
+    bo = tflite.BuiltinOptions.BuiltinOptions()
+    bo_gen = {
+        bo.Conv2DOptions: tflite.Conv2DOptions.Conv2DOptions,
+        bo.DepthwiseConv2DOptions: tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions,
+        bo.Pool2DOptions: tflite.Pool2DOptions.Pool2DOptions,
+        bo.FullyConnectedOptions: tflite.FullyConnectedOptions.FullyConnectedOptions,
+        bo.SoftmaxOptions: tflite.SoftmaxOptions.SoftmaxOptions,
+        bo.ConcatenationOptions: tflite.ConcatenationOptions.ConcatenationOptions,
+        bo.ReshapeOptions: tflite.ReshapeOptions.ReshapeOptions,
+        bo.AddOptions: tflite.AddOptions.AddOptions,
+        bo.SubOptions: tflite.SubOptions.SubOptions,
+        bo.MulOptions: tflite.MulOptions.MulOptions,
+        bo.DivOptions: tflite.DivOptions.DivOptions,
+        bo.ResizeBilinearOptions: tflite.ResizeBilinearOptions.ResizeBilinearOptions,
+        bo.StridedSliceOptions: tflite.StridedSliceOptions.StridedSliceOptions,
+        bo.CastOptions: tflite.CastOptions.CastOptions,
+        bo.TopKV2Options: tflite.TopKV2Options.TopKV2Options,
+        bo.GatherOptions: tflite.GatherOptions.GatherOptions
+    }
+
+    options_table = tf_operator.BuiltinOptions()
+    options = bo_gen[tf_operator.BuiltinOptionsType()]()
+    options.Init(options_table.Bytes, options_table.Pos)
+    return options
+
+
 class Operator(object):
     def __init__(self, operator_idx, tf_operator, input_tensors, output_tensors,
                  opcode_str):
@@ -57,6 +101,8 @@ class Operator(object):
         print("Operator {0}: {1} (ops: {2}, cycls: {3})".format(
             self.operator_idx, self.opcode_str, counts, cycles))
 
+        self.PrintOptionInfo()
+
         print("\tInput Tensors" + GetStrTensorIndex(self.inputs))
         for tensor in self.inputs:
             tensor.PrintInfo("\t\t")
@@ -64,6 +110,19 @@ class Operator(object):
         for tensor in self.outputs:
             tensor.PrintInfo("\t\t")
 
+    def PrintOptionInfo(self):
+        options = GetBuiltinOptions(self.tf_operator)
+
+        # fused activation function
+        try:
+            activation_code = options.FusedActivationFunction()
+            fused_activation = str(
+                activation_code)  # TODO print the name, not the integer code
+            print("\tFused Activation: " + fused_activation)
+        except AttributeError:
+            # This operator does not support FusedActivationFunction
+            pass
+
     def CountOperations(self):
         opcode_str = self.opcode_str
         # FIXME: if there would be a class for ops_counters, we can delete this