From 6d18dce4a4c84f75b761c3265a7418a54edcfe64 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Mon, 30 Jul 2018 13:19:36 +0900 Subject: [PATCH] [model_parser] Read BuiltinOptions dynamically (#2094) This commit removes the hardcoded implementation of BuiltinOptions Generator table. It now reads the list of Builtin Options from `BuiltinOptions.py` and automatically imports the corresponding classes and build the generator map dynamically. All these were possible with reflection. Signed-off-by: Hanjoung Lee --- tools/tflitefile_tool/operator_wrapping.py | 74 ++++++++++++------------------ 1 file changed, 30 insertions(+), 44 deletions(-) diff --git a/tools/tflitefile_tool/operator_wrapping.py b/tools/tflitefile_tool/operator_wrapping.py index 780fa20..1b7f55a 100755 --- a/tools/tflitefile_tool/operator_wrapping.py +++ b/tools/tflitefile_tool/operator_wrapping.py @@ -27,6 +27,7 @@ class EnumStrMaps(): BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator()) ActivationFunctionType = BuildEnumClassStrMap( tflite.ActivationFunctionType.ActivationFunctionType()) + BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions()) def GetStrTensorIndex(tensors): @@ -39,48 +40,33 @@ def GetStrTensorIndex(tensors): return return_string -# NOTE Currently not all builtin operations are supported -def GetBuiltinOptions(options_type, options_table): - import tflite.Conv2DOptions - import tflite.DepthwiseConv2DOptions - import tflite.Pool2DOptions - import tflite.FullyConnectedOptions - import tflite.SoftmaxOptions - import tflite.ConcatenationOptions - import tflite.ReshapeOptions - import tflite.AddOptions - import tflite.SubOptions - import tflite.MulOptions - import tflite.DivOptions - import tflite.ResizeBilinearOptions - import tflite.StridedSliceOptions - import tflite.CastOptions - import tflite.TopKV2Options - import tflite.GatherOptions - - bo = tflite.BuiltinOptions.BuiltinOptions() - bo_gen = { - bo.Conv2DOptions: tflite.Conv2DOptions.Conv2DOptions, - bo.DepthwiseConv2DOptions: tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions, - bo.Pool2DOptions: tflite.Pool2DOptions.Pool2DOptions, - bo.FullyConnectedOptions: tflite.FullyConnectedOptions.FullyConnectedOptions, - bo.SoftmaxOptions: tflite.SoftmaxOptions.SoftmaxOptions, - bo.ConcatenationOptions: tflite.ConcatenationOptions.ConcatenationOptions, - bo.ReshapeOptions: tflite.ReshapeOptions.ReshapeOptions, - bo.AddOptions: tflite.AddOptions.AddOptions, - bo.SubOptions: tflite.SubOptions.SubOptions, - bo.MulOptions: tflite.MulOptions.MulOptions, - bo.DivOptions: tflite.DivOptions.DivOptions, - bo.ResizeBilinearOptions: tflite.ResizeBilinearOptions.ResizeBilinearOptions, - bo.StridedSliceOptions: tflite.StridedSliceOptions.StridedSliceOptions, - bo.CastOptions: tflite.CastOptions.CastOptions, - bo.TopKV2Options: tflite.TopKV2Options.TopKV2Options, - bo.GatherOptions: tflite.GatherOptions.GatherOptions - } - - options = bo_gen[options_type]() - options.Init(options_table.Bytes, options_table.Pos) - return options +def GetAttribute(o, *args): + import functools + return functools.reduce(getattr, args, o) + + +def BuildBuiltinOptionGen(): + bo_gen = {} + for val_enum in EnumStrMaps.BuiltinOptions: + val_str = EnumStrMaps.BuiltinOptions[val_enum] + try: + # Dynamically import Builtin Option classes + # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class + module = __import__("tflite." + val_str) + bo_gen[val_enum] = GetAttribute(module, val_str, val_str) + except ImportError as e: + assert val_enum == 0 and val_str == "NONE" + return bo_gen + + +class OptionLoader: + builtinOptionGen = BuildBuiltinOptionGen() + + @staticmethod + def GetBuiltinOptions(options_type, options_table): + options = OptionLoader.builtinOptionGen[options_type]() + options.Init(options_table.Bytes, options_table.Pos) + return options class Operator(object): @@ -119,8 +105,8 @@ class Operator(object): def PrintOptionInfo(self): # FIXME: workaround for ops such as custom try: - options = GetBuiltinOptions(self.tf_operator.BuiltinOptionsType(), - self.tf_operator.BuiltinOptions()) + options = OptionLoader.GetBuiltinOptions( + self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions()) except KeyError: return -- 2.7.4