Use TFLiteConverter instead of deprecated API in tensorflow_model_freezer (#3917)
author윤지영/동작제어Lab(SR)/Engineer/삼성전자 <jy910.yun@samsung.com>
Tue, 11 Dec 2018 05:15:19 +0000 (14:15 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 11 Dec 2018 05:15:19 +0000 (14:15 +0900)
Add prerequires section in readme
Use tflite v1.12 tool to support the extended operators
Fix some tc errors

Signed-off-by: Jiyoung Yun <jy910.yun@samsung.com>
tools/tensorflow_model_freezer/base_freezer.py
tools/tensorflow_model_freezer/readme.md
tools/tensorflow_model_freezer/sample/MUL_gen.py
tools/tensorflow_model_freezer/sample/TOPK_gen.py

index ccfd811..96d0117 100644 (file)
@@ -129,8 +129,9 @@ class BaseFreezer(object):
 
         util.importGraphIntoSession(sess, frozen_pb_path, "")
         try:
-            tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, input_node_list,
-                                                        output_node_list)
+            converter = tf.contrib.lite.TFLiteConverter.from_session(
+                sess, input_node_list, output_node_list)
+            tflite_model = converter.convert()
             open(tflite_path, "wb").write(tflite_model)
             print("# 3. TOCO : Created TFLITE file :\n\t-{}\n".format(tflite_path))
         except Exception:
index f627f11..348e756 100644 (file)
@@ -1,3 +1,13 @@
+## Prerequisites
+
+The scripts here use TensorFlow's tools, so you need an environment to run TensorFlow. Running the scripts within this tutorial requires:
+
+* Install [TensorFlow](https://www.tensorflow.org/install/) v1.12 or later
+    * Use pip
+    ```
+    $ pip install tensorflow==1.12
+    ```
+
 ## What this tool is about
 
 This tool generaes the following files:
index f2a9254..b85f72e 100755 (executable)
@@ -59,8 +59,8 @@ class Gen(base.BaseFreezer):
             "mul_1d_scalarConst": [base.Tensor([5]),
                                    base.Tensor([], const_val=1.1)],  # mul by scalar
             "mul_2d_scalarConst": [base.Tensor([5, 3]),
-                                   base.Tensor([], const_val=1.1)],
-            "mul_1d_scalar": [base.Tensor([5, 3]), base.Tensor([])]
+                                   base.Tensor([], const_val=1.1)]
+            #            "mul_2d_scalar": [base.Tensor([5, 3]), base.Tensor([])]  # not support scalar input
         }
 
     def buildModel(self, sess, test_case_tensor, tc_name):
index 0c16d5b..170472f 100755 (executable)
@@ -63,6 +63,7 @@ class Gen(base.BaseFreezer):
         '''
 
         input_list = []
+        output_list = []
 
         # ------ modify below for your model FROM here -------#
 
@@ -70,11 +71,13 @@ class Gen(base.BaseFreezer):
         y_tensor = self.createTFInput(test_case_tensor[1], input_list)
 
         # defining output node and input list
-        output_node = tf.nn.top_k(
+        values_op, indices_op = tf.nn.top_k(
             x_tensor,
             y_tensor,  # add your input here
             name=tc_name)  # do not modify name
 
+        output_list.append(values_op)
+        output_list.append(indices_op)
         # ------ modify UNTIL here for your model -------#
 
         # Note if don't have any CONST value, creating checkpoint file fails.
@@ -92,7 +95,7 @@ class Gen(base.BaseFreezer):
         # ------ modify appropriate return value  -------#
 
         # returning (input_node_list, output_node_list)
-        return (input_list, [output_node])
+        return (input_list, output_list)
 
 
 '''