From 688cb2d1800ab9eddb04e9ac30ec236b99fc8695 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=98=A4=ED=98=95=EC=84=9D/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 18 May 2018 10:06:32 +0900 Subject: [PATCH] TFLiteFile tool: Fix shape bug, option (#1265) Fix tensor shape bug when tensor is scalar Fix generating option for new supported operators Signed-off-by: Hyeongseok Oh --- tools/tflitefile_tool/select_operator.py | 94 ++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 4 deletions(-) diff --git a/tools/tflitefile_tool/select_operator.py b/tools/tflitefile_tool/select_operator.py index fa7a268..c235d17 100755 --- a/tools/tflitefile_tool/select_operator.py +++ b/tools/tflitefile_tool/select_operator.py @@ -138,11 +138,11 @@ def GenerateTensor(new_builder, selected_tensor, used_buffers_dic): # Create shape vector for tensor shape_num = selected_tensor.ShapeLength() + tflite.Tensor.TensorStartShapeVector(new_builder, shape_num) if shape_num != 0: - tflite.Tensor.TensorStartShapeVector(new_builder, shape_num) for shape_idx in reversed(range(shape_num)): new_builder.PrependInt32(selected_tensor.Shape(shape_idx)) - new_shape = new_builder.EndVector(shape_num) + new_shape = new_builder.EndVector(shape_num) # Create tensor_type tensor_type = selected_tensor.Type() @@ -163,8 +163,7 @@ def GenerateTensor(new_builder, selected_tensor, used_buffers_dic): # Create tensor tflite.Tensor.TensorStart(new_builder) - if shape_num != 0: - tflite.Tensor.TensorAddShape(new_builder, new_shape) + tflite.Tensor.TensorAddShape(new_builder, new_shape) tflite.Tensor.TensorAddType(new_builder, tensor_type) tflite.Tensor.TensorAddBuffer(new_builder, new_buffer_idx) if name_string != "": @@ -206,6 +205,13 @@ import tflite.FullyConnectedOptions import tflite.SoftmaxOptions import tflite.ConcatenationOptions import tflite.ReshapeOptions +import tflite.AddOptions +import tflite.SubOptions +import tflite.MulOptions +import tflite.DivOptions +import tflite.ResizeBilinearOptions +import tflite.StridedSliceOptions +import tflite.CastOptions def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_type): @@ -314,6 +320,86 @@ def GenerateBuiltinOption(new_builder, selected_builtin_option, builtin_option_t tflite.ReshapeOptions.ReshapeOptionsAddNewShape(new_builder, new_shape) return tflite.ReshapeOptions.ReshapeOptionsEnd(new_builder) + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().AddOptions: + + add_option = tflite.AddOptions.AddOptions() + add_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.AddOptions.AddOptionsStart(new_builder) + tflite.AddOptions.AddOptionsAddFusedActivationFunction( + new_builder, add_option.FusedActivationFunction()) + return tflite.AddOptions.AddOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().SubOptions: + + sub_option = tflite.SubOptions.SubOptions() + sub_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.SubOptions.SubOptionsStart(new_builder) + tflite.SubOptions.SubOptionsAddFusedActivationFunction( + new_builder, sub_option.FusedActivationFunction()) + return tflite.SubOptions.SubOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().MulOptions: + + mul_option = tflite.MulOptions.MulOptions() + mul_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.MulOptions.MulOptionsStart(new_builder) + tflite.MulOptions.MulOptionsAddFusedActivationFunction( + new_builder, mul_option.FusedActivationFunction()) + return tflite.MulOptions.MulOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().DivOptions: + + div_option = tflite.DivOptions.DivOptions() + div_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.DivOptions.DivOptionsStart(new_builder) + tflite.DivOptions.DivOptionsAddFusedActivationFunction( + new_builder, div_option.FusedActivationFunction()) + return tflite.DivOptions.DivOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions( + ).ResizeBilinearOptions: + + resize_bilinear_option = tflite.ResizeBilinearOptions.ResizeBilinearOptions() + resize_bilinear_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) + + tflite.ResizeBilinearOptions.ResizeBilinearOptionsStart(new_builder) + tflite.ResizeBilinearOptions.ResizeBilinearOptionsAddAlignCorners( + new_builder, resize_bilinear_option.AlignCorners()) + return tflite.ResizeBilinearOptions.ResizeBilinearOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().StridedSliceOptions: + + stride_slice_option = tflite.StridedSliceOptions.StridedSliceOptions(new_builder) + stride_slice_option.Init(selected_builtin_option.Bytes, + selected_builtin_option.Pos) + + tflite.StridedSliceOptions.StridedSliceOptionsStart(new_builder) + tflite.StridedSliceOptions.StridedSliceOptionsAddBeginMask( + new_builder, stride_slice_option.BeginMask()) + tflite.StridedSliceOptions.StridedSliceOptionsAddEndMask( + new_builder, stride_slice_option.EndMask()) + tflite.StridedSliceOptions.StridedSliceOptionsAddEllipsisMask( + new_builder, stride_slice_option.EllipsisMask()) + tflite.StridedSliceOptions.StridedSliceOptionsAddNewAxisMask( + new_builder, stride_slice_option.NewAxisMask()) + tflite.StridedSliceOptions.StridedSliceOptionsAddShrinkAxisMask( + new_builder, stride_slice_option.ShrinkAxisMask()) + + return tflite.StridedSliceOptions.StridedSliceOptionsEnd(new_builder) + + if builtin_option_type == tflite.BuiltinOptions.BuiltinOptions().CastOptions: + + cast_option = tflite.CastOptions.CastOptions(new_builder) + cast_option.Init(selected_builtin_option.Bytes, selected_builtin_option.Pos) + + tflite.CastOptions.CastOptionsStart(new_builder) + return tflite.CastOptions.CastOptionsEnd(new_builder) + # Cannot handle builtin option type yet return 0 -- 2.7.4