From 94db5730c4d62be12a0ddad4b12b0b5d3df1e037 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9C=A4=EC=A7=80=EC=98=81/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Tue, 11 Dec 2018 14:15:19 +0900 Subject: [PATCH] Use TFLiteConverter instead of deprecated API in tensorflow_model_freezer (#3917) Add prerequires section in readme Use tflite v1.12 tool to support the extended operators Fix some tc errors Signed-off-by: Jiyoung Yun --- tools/tensorflow_model_freezer/base_freezer.py | 5 +++-- tools/tensorflow_model_freezer/readme.md | 10 ++++++++++ tools/tensorflow_model_freezer/sample/MUL_gen.py | 4 ++-- tools/tensorflow_model_freezer/sample/TOPK_gen.py | 7 +++++-- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tools/tensorflow_model_freezer/base_freezer.py b/tools/tensorflow_model_freezer/base_freezer.py index ccfd811..96d0117 100644 --- a/tools/tensorflow_model_freezer/base_freezer.py +++ b/tools/tensorflow_model_freezer/base_freezer.py @@ -129,8 +129,9 @@ class BaseFreezer(object): util.importGraphIntoSession(sess, frozen_pb_path, "") try: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, input_node_list, - output_node_list) + converter = tf.contrib.lite.TFLiteConverter.from_session( + sess, input_node_list, output_node_list) + tflite_model = converter.convert() open(tflite_path, "wb").write(tflite_model) print("# 3. TOCO : Created TFLITE file :\n\t-{}\n".format(tflite_path)) except Exception: diff --git a/tools/tensorflow_model_freezer/readme.md b/tools/tensorflow_model_freezer/readme.md index f627f11..348e756 100644 --- a/tools/tensorflow_model_freezer/readme.md +++ b/tools/tensorflow_model_freezer/readme.md @@ -1,3 +1,13 @@ +## Prerequisites + +The scripts here use TensorFlow's tools, so you need an environment to run TensorFlow. Running the scripts within this tutorial requires: + +* Install [TensorFlow](https://www.tensorflow.org/install/) v1.12 or later + * Use pip + ``` + $ pip install tensorflow==1.12 + ``` + ## What this tool is about This tool generaes the following files: diff --git a/tools/tensorflow_model_freezer/sample/MUL_gen.py b/tools/tensorflow_model_freezer/sample/MUL_gen.py index f2a9254..b85f72e 100755 --- a/tools/tensorflow_model_freezer/sample/MUL_gen.py +++ b/tools/tensorflow_model_freezer/sample/MUL_gen.py @@ -59,8 +59,8 @@ class Gen(base.BaseFreezer): "mul_1d_scalarConst": [base.Tensor([5]), base.Tensor([], const_val=1.1)], # mul by scalar "mul_2d_scalarConst": [base.Tensor([5, 3]), - base.Tensor([], const_val=1.1)], - "mul_1d_scalar": [base.Tensor([5, 3]), base.Tensor([])] + base.Tensor([], const_val=1.1)] + # "mul_2d_scalar": [base.Tensor([5, 3]), base.Tensor([])] # not support scalar input } def buildModel(self, sess, test_case_tensor, tc_name): diff --git a/tools/tensorflow_model_freezer/sample/TOPK_gen.py b/tools/tensorflow_model_freezer/sample/TOPK_gen.py index 0c16d5b..170472f 100755 --- a/tools/tensorflow_model_freezer/sample/TOPK_gen.py +++ b/tools/tensorflow_model_freezer/sample/TOPK_gen.py @@ -63,6 +63,7 @@ class Gen(base.BaseFreezer): ''' input_list = [] + output_list = [] # ------ modify below for your model FROM here -------# @@ -70,11 +71,13 @@ class Gen(base.BaseFreezer): y_tensor = self.createTFInput(test_case_tensor[1], input_list) # defining output node and input list - output_node = tf.nn.top_k( + values_op, indices_op = tf.nn.top_k( x_tensor, y_tensor, # add your input here name=tc_name) # do not modify name + output_list.append(values_op) + output_list.append(indices_op) # ------ modify UNTIL here for your model -------# # Note if don't have any CONST value, creating checkpoint file fails. @@ -92,7 +95,7 @@ class Gen(base.BaseFreezer): # ------ modify appropriate return value -------# # returning (input_node_list, output_node_list) - return (input_list, [output_node]) + return (input_list, output_list) ''' -- 2.7.4