TFLiteFile tool: Fix shape bug, option (#1265)
author오형석/동작제어Lab(SR)/Senior Engineer/삼성전자 <hseok82.oh@samsung.com>
Fri, 18 May 2018 01:06:32 +0000 (10:06 +0900)
committer김정현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh0822.kim@samsung.com>
Fri, 18 May 2018 01:06:32 +0000 (10:06 +0900)
Fix tensor shape bug when tensor is scalar
Fix generating option for new supported operators

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/tflitefile_tool/select_operator.py

index fa7a268..c235d17 100755 (executable)
@@ -138,11 +138,11 @@ def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
 
     # Create shape vector for tensor
     shape_num = selected_tensor.ShapeLength()
+    tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
     if shape_num != 0:
-        tflite.Tensor.TensorStartShapeVector(new_builder, shape_num)
         for shape_idx in reversed(range(shape_num)):
             new_builder.PrependInt32(selected_tensor.Shape(shape_idx))
-        new_shape = new_builder.EndVector(shape_num)
+    new_shape = new_builder.EndVector(shape_num)
 
     # Create tensor_type
     tensor_type = selected_tensor.Type()
@@ -163,8 +163,7 @@ def GenerateTensor(new_builder, selected_tensor, used_buffers_dic):
 
     # Create tensor
     tflite.Tensor.TensorStart(new_builder)
-    if shape_num != 0:
-        tflite.Tensor.TensorAddShape(new_builder, new_shape)
+    tflite.Tensor.TensorAddShape(new_builder, new_shape)
     tflite.Tensor.TensorAddType(new_builder, tensor_type)
     tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx)
     if name_string != "":
@@ -206,6 +205,13 @@ import tflite.FullyConnectedOptions
 import tflite.SoftmaxOptions
 import tflite.ConcatenationOptions
 import tflite.ReshapeOptions
+import tflite.AddOptions
+import tflite.SubOptions
+import tflite.MulOptions
+import tflite.DivOptions
+import tflite.ResizeBilinearOptions
+import tflite.StridedSliceOptions
+import tflite.CastOptions
 
 
 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type):
@@ -314,6 +320,86 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t
             tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
         return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
 
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions:
+
+        add_option = tflite.AddOptions.AddOptions()
+        add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.AddOptions.AddOptionsStart(new_builder)
+        tflite.AddOptions.AddOptionsAddFusedActivationFunction(
+            new_builder, add_option.FusedActivationFunction())
+        return tflite.AddOptions.AddOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions:
+
+        sub_option = tflite.SubOptions.SubOptions()
+        sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.SubOptions.SubOptionsStart(new_builder)
+        tflite.SubOptions.SubOptionsAddFusedActivationFunction(
+            new_builder, sub_option.FusedActivationFunction())
+        return tflite.SubOptions.SubOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions:
+
+        mul_option = tflite.MulOptions.MulOptions()
+        mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.MulOptions.MulOptionsStart(new_builder)
+        tflite.MulOptions.MulOptionsAddFusedActivationFunction(
+            new_builder, mul_option.FusedActivationFunction())
+        return tflite.MulOptions.MulOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions:
+
+        div_option = tflite.DivOptions.DivOptions()
+        div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.DivOptions.DivOptionsStart(new_builder)
+        tflite.DivOptions.DivOptionsAddFusedActivationFunction(
+            new_builder, div_option.FusedActivationFunction())
+        return tflite.DivOptions.DivOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions(
+    ).ResizeBilinearOptions:
+
+        resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions()
+        resize_bilinear_option.Init(selected_builtin_option.Bytes,
+                                    selected_builtin_option.Pos)
+
+        tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder)
+        tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners(
+            new_builder, resize_bilinear_option.AlignCorners())
+        return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions:
+
+        stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions(new_builder)
+        stride_slice_option.Init(selected_builtin_option.Bytes,
+                                 selected_builtin_option.Pos)
+
+        tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder)
+        tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask(
+            new_builder, stride_slice_option.BeginMask())
+        tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask(
+            new_builder, stride_slice_option.EndMask())
+        tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask(
+            new_builder, stride_slice_option.EllipsisMask())
+        tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask(
+            new_builder, stride_slice_option.NewAxisMask())
+        tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask(
+            new_builder, stride_slice_option.ShrinkAxisMask())
+
+        return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder)
+
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions:
+
+        cast_option = tflite.CastOptions.CastOptions(new_builder)
+        cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.CastOptions.CastOptionsStart(new_builder)
+        return tflite.CastOptions.CastOptionsEnd(new_builder)
+
     # Cannot handle builtin option type yet
     return 0