From 8599f7c61c84b70adec659c8e0c99ac963ce0796 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Fri, 1 May 2020 17:02:24 -0700 Subject: [PATCH] [TFLite] Model importer to be compatible with tflite 2.1.0 (#5497) --- golang/sample/gen_mobilenet_lib.py | 8 ++++++-- python/tvm/relay/frontend/tflite.py | 12 +++++++++--- tests/python/frontend/tflite/test_forward.py | 8 +++++--- tutorials/frontend/from_tflite.py | 19 +++---------------- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/golang/sample/gen_mobilenet_lib.py b/golang/sample/gen_mobilenet_lib.py index 4f6a615..8becd07 100644 --- a/golang/sample/gen_mobilenet_lib.py +++ b/golang/sample/gen_mobilenet_lib.py @@ -18,7 +18,6 @@ import os from tvm import relay from tvm.contrib.download import download_testdata -import tflite.Model ################################################ @@ -49,7 +48,12 @@ model_file = os.path.join(model_dir, "mobilenet_v2_1.4_224.tflite") # get TFLite model from buffer tflite_model_buf = open(model_file, "rb").read() -tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) +try: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) +except AttributeError: + import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) ############################## diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5c8bbfb..703ef9c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2524,7 +2524,7 @@ def from_tflite(model, shape_dict, dtype_dict): Parameters ---------- model: - tflite.Model.Model + tflite.Model or tflite.Model.Model (depending on tflite version) shape_dict : dict of str to int list/tuple Input shapes of the model. @@ -2541,12 +2541,18 @@ def from_tflite(model, shape_dict, dtype_dict): The parameter dict to be used by relay """ try: - import tflite.Model import tflite.SubGraph import tflite.BuiltinOperator except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(model, tflite.Model.Model) + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite + assert isinstance(model, tflite.Model) + except TypeError: + import tflite.Model + assert isinstance(model, tflite.Model.Model) # keep the same as tflite assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 75146c3..283d87d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -76,14 +76,16 @@ def get_real_image(im_height, im_width): def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """ + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) except ImportError: raise ImportError("The tflite package must be installed") - # get TFLite model from buffer - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index 3273855..e01a4ec 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -21,25 +21,12 @@ Compile TFLite Models This article is an introductory tutorial to deploy TFLite models with Relay. -To get started, Flatbuffers and TFLite package needs to be installed as prerequisites. -A quick solution is to install Flatbuffers via pip +To get started, TFLite package needs to be installed as prerequisite. .. code-block:: bash - pip install flatbuffers --user - - -To install TFlite packages, you could use our prebuilt wheel: - -.. code-block:: bash - - # For python3: - wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl - pip3 install -U tflite-1.13.1-py3-none-any.whl --user - - # For python2: - wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl - pip install -U tflite-1.13.1-py2-none-any.whl --user + # install tflite + pip install tflite=2.1.0 --user or you could generate TFLite package yourself. The steps are the following: -- 2.7.4