Handle reshape and depthwise convolution in operator selector (#476)
author오형석/동작제어Lab(SR)/Senior Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 9 Apr 2018 14:13:53 +0000 (23:13 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 9 Apr 2018 14:13:53 +0000 (23:13 +0900)
Operator selector (select_operator.py) can handle new two operators

- Reshape
- Depthwise convolution

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
tools/tflitefile_tool/select_operator.py

index a54d94f..98e25ad 100755 (executable)
@@ -202,10 +202,12 @@ def GenerateTensors(new_builder, selected_subgraph, used_tensors_dic, used_buffe
     return new_builder.EndVector(new_tensor_num)
 
 import tflite.Conv2DOptions
+import tflite.DepthwiseConv2DOptions
 import tflite.Pool2DOptions
 import tflite.FullyConnectedOptions
 import tflite.SoftmaxOptions
 import tflite.ConcatenationOptions
+import tflite.ReshapeOptions
 
 def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type):
 
@@ -221,6 +223,19 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t
         tflite.Conv2DOptions.Conv2DOptionsAddFusedActivationFunction(new_builder, conv2d_options.FusedActivationFunction())
         return tflite.Conv2DOptions.Conv2DOptionsEnd(new_builder)
 
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DepthwiseConv2DOptions:
+
+        depthconv2d_option = tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions()
+        depthconv2d_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsStart(new_builder)
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddPadding(new_builder, depthconv2d_option.Padding())
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideW(new_builder, depthconv2d_option.StrideW())
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddStrideH(new_builder, depthconv2d_option.StrideH())
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddDepthMultiplier(new_builder, depthconv2d_option.DepthMultiplier())
+        tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsAddFusedActivationFunction(new_builder, depthconv2d_option.FusedActivationFunction())
+        return tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptionsEnd(new_builder)
+
     if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().Pool2DOptions:
 
         pool2d_option = tflite.Pool2DOptions.Pool2DOptions()
@@ -263,6 +278,24 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t
         tflite.ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(new_builder, concat_option.FusedActivationFunction())
         return tflite.ConcatenationOptions.ConcatenationOptionsEnd(new_builder)
 
+    if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().ReshapeOptions:
+
+        reshape_option = tflite.ReshapeOptions.ReshapeOptions()
+        reshape_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos)
+
+        shape_num = reshape_option.NewShapeLength()
+        if shape_num != 0:
+            tflite.ReshapeOptions.ReshapeOptionsStartNewShapeVector(new_builder, shape_num)
+            for new_shape_idx in reversed(range(shape_num)):
+                new_shape_val = reshape_option.NewShape(new_shape_idx)
+                new_builder.PrependInt32(new_shape_val)
+            new_shape = new_builder.EndVector(shape_num)
+
+        tflite.ReshapeOptions.ReshapeOptionsStart(new_builder)
+        if shape_num != 0:
+            tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape)
+        return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder)
+
     # Cannot handle builtin option type yet
     return 0