[model_parser] Read BuiltinOptions dynamically (#2094)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Mon, 30 Jul 2018 04:19:36 +0000 (13:19 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 30 Jul 2018 04:19:36 +0000 (13:19 +0900)
This commit removes the hardcoded implementation of BuiltinOptions
Generator table. It now reads the list of Builtin Options from
`BuiltinOptions.py` and automatically imports the corresponding
classes and build the generator map dynamically. All these were
possible with reflection.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
tools/tflitefile_tool/operator_wrapping.py

index 780fa20..1b7f55a 100755 (executable)
@@ -27,6 +27,7 @@ class EnumStrMaps():
     BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator())
     ActivationFunctionType = BuildEnumClassStrMap(
         tflite.ActivationFunctionType.ActivationFunctionType())
+    BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions())
 
 
 def GetStrTensorIndex(tensors):
@@ -39,48 +40,33 @@ def GetStrTensorIndex(tensors):
     return return_string
 
 
-# NOTE Currently not all builtin operations are supported
-def GetBuiltinOptions(options_type, options_table):
-    import tflite.Conv2DOptions
-    import tflite.DepthwiseConv2DOptions
-    import tflite.Pool2DOptions
-    import tflite.FullyConnectedOptions
-    import tflite.SoftmaxOptions
-    import tflite.ConcatenationOptions
-    import tflite.ReshapeOptions
-    import tflite.AddOptions
-    import tflite.SubOptions
-    import tflite.MulOptions
-    import tflite.DivOptions
-    import tflite.ResizeBilinearOptions
-    import tflite.StridedSliceOptions
-    import tflite.CastOptions
-    import tflite.TopKV2Options
-    import tflite.GatherOptions
-
-    bo = tflite.BuiltinOptions.BuiltinOptions()
-    bo_gen = {
-        bo.Conv2DOptions: tflite.Conv2DOptions.Conv2DOptions,
-        bo.DepthwiseConv2DOptions: tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions,
-        bo.Pool2DOptions: tflite.Pool2DOptions.Pool2DOptions,
-        bo.FullyConnectedOptions: tflite.FullyConnectedOptions.FullyConnectedOptions,
-        bo.SoftmaxOptions: tflite.SoftmaxOptions.SoftmaxOptions,
-        bo.ConcatenationOptions: tflite.ConcatenationOptions.ConcatenationOptions,
-        bo.ReshapeOptions: tflite.ReshapeOptions.ReshapeOptions,
-        bo.AddOptions: tflite.AddOptions.AddOptions,
-        bo.SubOptions: tflite.SubOptions.SubOptions,
-        bo.MulOptions: tflite.MulOptions.MulOptions,
-        bo.DivOptions: tflite.DivOptions.DivOptions,
-        bo.ResizeBilinearOptions: tflite.ResizeBilinearOptions.ResizeBilinearOptions,
-        bo.StridedSliceOptions: tflite.StridedSliceOptions.StridedSliceOptions,
-        bo.CastOptions: tflite.CastOptions.CastOptions,
-        bo.TopKV2Options: tflite.TopKV2Options.TopKV2Options,
-        bo.GatherOptions: tflite.GatherOptions.GatherOptions
-    }
-
-    options = bo_gen[options_type]()
-    options.Init(options_table.Bytes, options_table.Pos)
-    return options
+def GetAttribute(o, *args):
+    import functools
+    return functools.reduce(getattr, args, o)
+
+
+def BuildBuiltinOptionGen():
+    bo_gen = {}
+    for val_enum in EnumStrMaps.BuiltinOptions:
+        val_str = EnumStrMaps.BuiltinOptions[val_enum]
+        try:
+            # Dynamically import Builtin Option classes
+            # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class
+            module = __import__("tflite." + val_str)
+            bo_gen[val_enum] = GetAttribute(module, val_str, val_str)
+        except ImportError as e:
+            assert val_enum == 0 and val_str == "NONE"
+    return bo_gen
+
+
+class OptionLoader:
+    builtinOptionGen = BuildBuiltinOptionGen()
+
+    @staticmethod
+    def GetBuiltinOptions(options_type, options_table):
+        options = OptionLoader.builtinOptionGen[options_type]()
+        options.Init(options_table.Bytes, options_table.Pos)
+        return options
 
 
 class Operator(object):
@@ -119,8 +105,8 @@ class Operator(object):
     def PrintOptionInfo(self):
         # FIXME: workaround for ops such as custom
         try:
-            options = GetBuiltinOptions(self.tf_operator.BuiltinOptionsType(),
-                                        self.tf_operator.BuiltinOptions())
+            options = OptionLoader.GetBuiltinOptions(
+                self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions())
         except KeyError:
             return