[tool] TFLITE file generator for Squeeze operations (#1948)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 18 Jul 2018 07:33:35 +0000 (16:33 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Wed, 18 Jul 2018 07:33:35 +0000 (16:33 +0900)
This file creates the following tflites:
```
/home/eric/models/squeeze/squeeze_2d.tflite
/home/eric/models/squeeze/squeeze_4d_1.tflite
/home/eric/models/squeeze/squeeze_4d_2.tflite
/home/eric/models/squeeze/squeeze_4d_3.tflite
/home/eric/models/squeeze/squeeze_4d_4.tflite
```
which are
- squeeze_2d.tflite: `squeeze (Tensor(shape=[1, 3]))`
- squeeze_4d_1.tflite: `squeeze (Tensor(shape=[1, 3, 2, 1]))`
- squeeze_4d_2.tflite: `squeeze (Tensor(shape=[1, 3, 2, 1]), axis=[0]) # squeeze [1, 3, 2, 1] to [3, 2, 1]`
- squeeze_4d_3.tflite: `squeeze (shape=Tensor([1, 3, 2, 1]), axis=[3])  # squeeze [1, 3, 2, 1] to [1, 3, 2]`
- squeeze_4d_4.tflite: `squeeze (shape=Tensor([1, 3, 2, 1]), axis=[0, 3])  # squeeze [1, 3, 2, 1] to [3, 2]`

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py [new file with mode: 0755]

diff --git a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py
new file mode 100755 (executable)
index 0000000..88b3dfc
--- /dev/null
@@ -0,0 +1,127 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+    '''
+    class to generate tflite files for Squeeze
+    '''
+
+    def __init__(self, path):
+        super(self.__class__, self).__init__(path)
+
+    def getOutputDirectory(self):
+        return os.path.join(self.root_output_path,
+                            'squeeze')  # the root path of generated files
+
+    def getTestCases(self):
+        '''
+        this returns a a hash containg test cases.
+        key of return hash is test case name and
+        value of return hash is test is a list of input tensor metadata.
+        test name (key of hash) is used as
+            - prefix of file name to be generated (don't use white space or special characters)
+            - output node name pf graph
+        '''
+        return {
+            "squeeze_2d": [base.Tensor([1, 3])],
+            "squeeze_4d_1": [base.Tensor([1, 3, 2, 1])],
+            # squeeze with axis
+            "squeeze_4d_2": [base.Tensor([1, 3, 2, 1]),
+                             [0]],  # squeeze [1, 3, 2, 1] to [3, 2, 1]
+            "squeeze_4d_3": [base.Tensor([1, 3, 2, 1]),
+                             [3]],  # squeeze [1, 3, 2, 1] to [1, 3, 2]
+            "squeeze_4d_4": [base.Tensor([1, 3, 2, 1]),
+                             [0, 3]]  # squeeze [1, 3, 2, 1] to [3, 2]
+        }
+
+    def buildModel(self, sess, test_case_tensor, tc_name):
+        '''
+        This method is called per test case (defined by getTestCases()).
+
+        keyword argument:
+        test_case_tensor -- test case tensor metadata
+            For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+            test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+        '''
+
+        input_list = []
+
+        # ------ modify below for your model FROM here -------#
+
+        x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+        if len(test_case_tensor) == 2:
+            axis_tensor = test_case_tensor[1]
+
+        # defining output node = x_input * y_input
+        # and input list
+        if len(test_case_tensor) == 1:
+            output_node = tf.squeeze(input=x_tensor, name=tc_name)  # do not modify name
+        else:
+            output_node = tf.squeeze(
+                input=x_tensor, axis=axis_tensor, name=tc_name)  # do not modify name
+
+        # ------ modify UNTIL here for your model -------#
+
+        # Note if don't have any CONST value, creating checkpoint file fails.
+        # The next lines insert such (CONST) to prevent such error.
+        # So, Graph.pb/pbtxt contains this garbage info,
+        # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+        garbage = tf.get_variable(
+            "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+        init_op = tf.global_variables_initializer()
+        garbage_value = [0]
+        sess.run(tf.assign(garbage, garbage_value))
+
+        sess.run(init_op)
+
+        # ------ modify appropriate return value  -------#
+
+        # returning (input_node_list, output_node_list)
+        return (input_list, [output_node])
+
+
+'''
+How to run
+$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
+$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
+      tools/tensorflow_model_freezer/sample/name_of_this_file.py \
+      ~/temp  # directory where model files are saved
+'''
+# --------
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description='Converted Tensorflow model in python to frozen model.')
+    parser.add_argument(
+        "out_dir",
+        help=
+        "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+    )
+
+    args = parser.parse_args()
+    root_output_path = args.out_dir
+
+    Gen(root_output_path).createSaveFreezeModel()