From: 윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 Date: Wed, 18 Jul 2018 07:33:35 +0000 (+0900) Subject: [tool] TFLITE file generator for Squeeze operations (#1948) X-Git-Tag: 0.2~440 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b44ef3360fc83db217d68e1b2e9bb9bf75d0c4d7;p=platform%2Fcore%2Fml%2Fnnfw.git [tool] TFLITE file generator for Squeeze operations (#1948) This file creates the following tflites: ``` /home/eric/models/squeeze/squeeze_2d.tflite /home/eric/models/squeeze/squeeze_4d_1.tflite /home/eric/models/squeeze/squeeze_4d_2.tflite /home/eric/models/squeeze/squeeze_4d_3.tflite /home/eric/models/squeeze/squeeze_4d_4.tflite ``` which are - squeeze_2d.tflite: `squeeze (Tensor(shape=[1, 3]))` - squeeze_4d_1.tflite: `squeeze (Tensor(shape=[1, 3, 2, 1]))` - squeeze_4d_2.tflite: `squeeze (Tensor(shape=[1, 3, 2, 1]), axis=[0]) # squeeze [1, 3, 2, 1] to [3, 2, 1]` - squeeze_4d_3.tflite: `squeeze (shape=Tensor([1, 3, 2, 1]), axis=[3]) # squeeze [1, 3, 2, 1] to [1, 3, 2]` - squeeze_4d_4.tflite: `squeeze (shape=Tensor([1, 3, 2, 1]), axis=[0, 3]) # squeeze [1, 3, 2, 1] to [3, 2]` Signed-off-by: Hyun Sik Yoon --- diff --git a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py new file mode 100755 index 0000000..88b3dfc --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py @@ -0,0 +1,127 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for Squeeze + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'squeeze') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return { + "squeeze_2d": [base.Tensor([1, 3])], + "squeeze_4d_1": [base.Tensor([1, 3, 2, 1])], + # squeeze with axis + "squeeze_4d_2": [base.Tensor([1, 3, 2, 1]), + [0]], # squeeze [1, 3, 2, 1] to [3, 2, 1] + "squeeze_4d_3": [base.Tensor([1, 3, 2, 1]), + [3]], # squeeze [1, 3, 2, 1] to [1, 3, 2] + "squeeze_4d_4": [base.Tensor([1, 3, 2, 1]), + [0, 3]] # squeeze [1, 3, 2, 1] to [3, 2] + } + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + if len(test_case_tensor) == 2: + axis_tensor = test_case_tensor[1] + + # defining output node = x_input * y_input + # and input list + if len(test_case_tensor) == 1: + output_node = tf.squeeze(input=x_tensor, name=tc_name) # do not modify name + else: + output_node = tf.squeeze( + input=x_tensor, axis=axis_tensor, name=tc_name) # do not modify name + + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +''' +How to run +$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py +$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \ + tools/tensorflow_model_freezer/sample/name_of_this_file.py \ + ~/temp # directory where model files are saved +''' +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel()