BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator())
ActivationFunctionType = BuildEnumClassStrMap(
tflite.ActivationFunctionType.ActivationFunctionType())
+ BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions())
def GetStrTensorIndex(tensors):
return return_string
-# NOTE Currently not all builtin operations are supported
-def GetBuiltinOptions(options_type, options_table):
- import tflite.Conv2DOptions
- import tflite.DepthwiseConv2DOptions
- import tflite.Pool2DOptions
- import tflite.FullyConnectedOptions
- import tflite.SoftmaxOptions
- import tflite.ConcatenationOptions
- import tflite.ReshapeOptions
- import tflite.AddOptions
- import tflite.SubOptions
- import tflite.MulOptions
- import tflite.DivOptions
- import tflite.ResizeBilinearOptions
- import tflite.StridedSliceOptions
- import tflite.CastOptions
- import tflite.TopKV2Options
- import tflite.GatherOptions
-
- bo = tflite.BuiltinOptions.BuiltinOptions()
- bo_gen = {
- bo.Conv2DOptions: tflite.Conv2DOptions.Conv2DOptions,
- bo.DepthwiseConv2DOptions: tflite.DepthwiseConv2DOptions.DepthwiseConv2DOptions,
- bo.Pool2DOptions: tflite.Pool2DOptions.Pool2DOptions,
- bo.FullyConnectedOptions: tflite.FullyConnectedOptions.FullyConnectedOptions,
- bo.SoftmaxOptions: tflite.SoftmaxOptions.SoftmaxOptions,
- bo.ConcatenationOptions: tflite.ConcatenationOptions.ConcatenationOptions,
- bo.ReshapeOptions: tflite.ReshapeOptions.ReshapeOptions,
- bo.AddOptions: tflite.AddOptions.AddOptions,
- bo.SubOptions: tflite.SubOptions.SubOptions,
- bo.MulOptions: tflite.MulOptions.MulOptions,
- bo.DivOptions: tflite.DivOptions.DivOptions,
- bo.ResizeBilinearOptions: tflite.ResizeBilinearOptions.ResizeBilinearOptions,
- bo.StridedSliceOptions: tflite.StridedSliceOptions.StridedSliceOptions,
- bo.CastOptions: tflite.CastOptions.CastOptions,
- bo.TopKV2Options: tflite.TopKV2Options.TopKV2Options,
- bo.GatherOptions: tflite.GatherOptions.GatherOptions
- }
-
- options = bo_gen[options_type]()
- options.Init(options_table.Bytes, options_table.Pos)
- return options
+def GetAttribute(o, *args):
+ import functools
+ return functools.reduce(getattr, args, o)
+
+
+def BuildBuiltinOptionGen():
+ bo_gen = {}
+ for val_enum in EnumStrMaps.BuiltinOptions:
+ val_str = EnumStrMaps.BuiltinOptions[val_enum]
+ try:
+ # Dynamically import Builtin Option classes
+ # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class
+ module = __import__("tflite." + val_str)
+ bo_gen[val_enum] = GetAttribute(module, val_str, val_str)
+ except ImportError as e:
+ assert val_enum == 0 and val_str == "NONE"
+ return bo_gen
+
+
+class OptionLoader:
+ builtinOptionGen = BuildBuiltinOptionGen()
+
+ @staticmethod
+ def GetBuiltinOptions(options_type, options_table):
+ options = OptionLoader.builtinOptionGen[options_type]()
+ options.Init(options_table.Bytes, options_table.Pos)
+ return options
class Operator(object):
def PrintOptionInfo(self):
# FIXME: workaround for ops such as custom
try:
- options = GetBuiltinOptions(self.tf_operator.BuiltinOptionsType(),
- self.tf_operator.BuiltinOptions())
+ options = OptionLoader.GetBuiltinOptions(
+ self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions())
except KeyError:
return