Early version of freezer ( Tensorflow model (code level) to frozen model ) (#419)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Tue, 10 Apr 2018 06:47:50 +0000 (15:47 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Tue, 10 Apr 2018 06:47:50 +0000 (15:47 +0900)
This commit introduces a tool to convert Tensorflow model (in code level) to frozen model file.
This could be used in future to generate many frozen model files for test cases.
The following samples to generate frozen model are included.

- relu(wx+b)
- softmax(wx+b)

tools/tensorflow_model_freezer/freeze_programmed_tensor_graph.py [new file with mode: 0644]
tools/tensorflow_model_freezer/model_freezer_util.py [new file with mode: 0644]
tools/tensorflow_model_freezer/readme.md [new file with mode: 0644]

diff --git a/tools/tensorflow_model_freezer/freeze_programmed_tensor_graph.py b/tools/tensorflow_model_freezer/freeze_programmed_tensor_graph.py
new file mode 100644 (file)
index 0000000..f952cf4
--- /dev/null
@@ -0,0 +1,202 @@
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import model_freezer_util as util
+
+# --------
+class TensorFlowModelFreezer(object):
+
+    def __init__(self, path):
+        # files generated by child class will be stored under this path.
+        self.root_output_path = path
+
+    def createSaveFreezeModel(self):
+        ''' abstract class '''
+        raise NotImplementedError("please implement this")
+
+    def getOutputDirectory(self):
+        ''' abstract class
+            override method should return directory under self.root_output_path where all pb, pbtxt, checkpoing, tensorboard log are saved '''
+        raise NotImplementedError("please implement this")
+
+    def getTopNodeName(self):
+        ''' abstract class
+            override method should return top node of frozen graph '''
+        raise NotImplementedError("please implement this")
+
+    def saveRelatedFiles(self, sess):
+        ''' saves pb, pbtxt, chpt files and then freeze graph under top_node_name into directory '''
+
+        ''' produce pb, pbtxt, and ckpt files '''
+        (pb_path, pbtxt_path, checkpoint_path) = util.savePbAndCkpt(
+            sess, self.getOutputDirectory())
+
+        print("")
+        print("# Success. pb, pbtxt, checkpoint files are created")
+        print("  -> {}, {}, {}\n".format(pb_path, pbtxt_path, checkpoint_path))
+
+        '''
+        produce frozen_graph files
+        include only nodes below softmax node. nodes for gradient descent (reduce_mean, GradientDescentOptimizer, ...) will not be included
+        '''
+        (frozen_pb_path, frozen_pbtxt_path) = util.freezeGraph(
+            pb_path, checkpoint_path, self.getTopNodeName())
+
+        print("")
+        print("\n# Success. A frozen pb and pbtxt are created")
+        print("  -> {}, {}\n".format(frozen_pb_path, frozen_pbtxt_path))
+
+        return (pb_path, frozen_pb_path)
+
+    def generateTensorboardLog(self, pb_path, frozen_pb_path):
+        ''' generating tensorboard logs to compare original pb and frozen pb '''
+        tensorboardLogDir = util.generateTensorboardLog(
+            [pb_path, frozen_pb_path], ['original', 'frozen'], self.getOutputDirectory())
+
+        print("# You can view original graph and frozen graph with tensorboard.")
+        print(
+            "  Run the following: $ tensorboard --logdir={} ".format(tensorboardLogDir))
+
+# --------
+class SoftmaxTestModelFreezer (TensorFlowModelFreezer):
+    ''' class to create, save, and freeze Softmax(wx+b) '''
+
+    def __init__(self, path):
+        super(SoftmaxTestModelFreezer, self).__init__(path)
+
+    def getOutputDirectory(self):
+        assert self.root_output_path
+        return os.path.join(self.root_output_path, 'softmax')
+
+    def getTopNodeName(self):
+        return "SOFTMAX_TOP"
+
+    def createSaveFreezeModel(self):
+        ''' this sample product frozen graph for 'softmax(wx+b)' into 'product/softmax' directory '''
+
+        print("")
+        print("-------------------- freezing Softmax(wx+b) ------------------------")
+        print("# files will be saved into " + self.getOutputDirectory())
+        print("")
+
+        tf.reset_default_graph()  # without this, graph used previous session is reused : https://stackoverflow.com/questions/42706761/closing-session-in-tensorflow-doesnt-reset-graph
+
+        X = tf.placeholder("float", [None, 4])
+        Y = tf.placeholder("float", [None, 3])
+
+        W = tf.Variable(tf.random_normal([4, 3]), name='weight')
+        b = tf.Variable(tf.random_normal([3]), name='bias')
+
+        softmax_top = tf.nn.softmax(
+            tf.matmul(X, W) + b, name=self.getTopNodeName())
+
+        cost_function = tf.reduce_mean(-tf.reduce_sum(Y *
+                                                      tf.log(softmax_top), axis=1))
+
+        gd = tf.train.GradientDescentOptimizer(
+            learning_rate=0.11).minimize(cost_function)
+
+        with tf.Session() as sess:
+            sess.run(tf.global_variables_initializer())
+
+            print("# start to training")
+            x_input = [[1, 2, 3, 4],[5, 6, 7, 8],[1, 2, 3, 4],[5, 6, 7, 8],[1, 3, 5, 7],[7, 5, 3, 1],[1, 3, 2, 5],[4, 7, 6, 9]]
+            y_input = [[0, 0, 1],[0, 1, 0],[1, 0, 0],[0, 0, 1],[0, 1, 0],[1, 0, 0],[0, 0, 1],[0, 1, 0]]
+
+            for step in range(500):
+                sess.run(gd, feed_dict={X: x_input, Y: y_input})
+
+            print("# training is done. Now we have some weight and bias values")
+
+            ''' Now, save to proto buffer format and checkpoint '''
+            (pb_path, frozen_pb_path) = self.saveRelatedFiles(sess)
+
+        self.generateTensorboardLog(pb_path, frozen_pb_path)
+
+# --------
+class ReluTestModelFreezer(TensorFlowModelFreezer):
+    ''' class to create, save, and freeze relu(wx+b) '''
+
+    def __init__(self, path):
+        super(ReluTestModelFreezer, self).__init__(path)
+
+    def getOutputDirectory(self):
+        return os.path.join(self.root_output_path, 'relu')
+
+    def getTopNodeName(self):
+        return "RELU_TOP"
+
+    def createSaveFreezeModel(self):
+
+        print("")
+        print("-------------------- freezing Relu(wx+b) ------------------------")
+        print("# files will be saved into " + self.getOutputDirectory())
+        print("")
+
+        tf.reset_default_graph()  # without this, graph used previous session is reused : https://stackoverflow.com/questions/42706761/closing-session-in-tensorflow-doesnt-reset-graph
+
+        X = tf.placeholder(
+            tf.float32, shape=[
+                None, 3], name='X_placeholder')  # input
+        W = tf.get_variable("W_var", [3, 2], dtype=tf.float32, initializer=tf.zeros_initializer())
+        b = tf.get_variable("b_var", [2, 2], dtype=tf.float32, initializer=tf.zeros_initializer())
+
+        O = tf.nn.relu(tf.matmul(X, W) + b, name=self.getTopNodeName())  # activation / output
+
+        init_op = tf.global_variables_initializer()
+
+        with tf.Session() as sess:
+            sess.run(init_op)
+            # normally you would do some training here
+            # we will just assign something to W
+            sess.run(tf.assign(W, [[2, 4], [8, 16], [32, 64]]))
+            sess.run(tf.assign(b, [[128, 256], [512, 1024]]))
+
+            ''' Now, save to proto buffer format and checkpoint '''
+            (pb_path, frozen_pb_path) = self.saveRelatedFiles(sess)
+
+        self.generateTensorboardLog(pb_path, frozen_pb_path)
+
+# --------
+def main():
+
+    global root_output_path
+
+    if platform.python_version()[0] != '2':
+        print("python version must be 2.x")
+        sys.exit(0)
+
+    freezer = SoftmaxTestModelFreezer(root_output_path)
+    freezer.createSaveFreezeModel()
+
+    freezer = ReluTestModelFreezer(root_output_path)
+    freezer.createSaveFreezeModel()
+
+
+# --------
+# --------
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(
+        description='Converted Tensorflow model in python to frozen model.')
+    parser.add_argument(
+        "out_dir",
+        help="directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored.")
+    parser.add_argument(
+        "-r",
+        action='store_true',
+        help="remove any existing out_dir. If -r is not provided, the directory is overwritten.")
+
+    args = parser.parse_args()
+    root_output_path = args.out_dir
+
+    if args.r == True:
+        if os.path.exists(root_output_path):
+            import shutil
+            shutil.rmtree(root_output_path)
+            print("# Removed directory " + root_output_path)
+
+    main()
diff --git a/tools/tensorflow_model_freezer/model_freezer_util.py b/tools/tensorflow_model_freezer/model_freezer_util.py
new file mode 100644 (file)
index 0000000..4f03e6a
--- /dev/null
@@ -0,0 +1,184 @@
+# utility for nncc
+
+import os
+import sys
+
+import tensorflow as tf
+from google.protobuf import text_format
+from tensorflow.python.platform import gfile
+from tensorflow.python.tools import freeze_graph
+from tensorflow.python.tools import optimize_for_inference_lib
+
+# --------
+def file_validity_check(fn, ext_must_be=''):
+    ''' check if file exist and file extention is corrent '''
+    if os.path.exists(fn) == False:
+        print("# error: file does not exist " + fn)
+        return False
+
+    if ext_must_be != '':
+        ext = os.path.splitext(fn)[1]
+        if ext[1:].lower() != ext_must_be:     # ext contains , e.g., '.pb'. need to exclud '.'
+            print("# error: wrong extension {}. Should be {} ".format(ext, ext_must_be))
+            return False
+
+    return True
+
+# --------
+def importGraphIntoSession(sess, filename, graphNameAfterImporting):
+    # this should be called inside
+    # with tf.Session() as sess:
+    assert sess
+    (_, _, ext) = splitDirFilenameExt(filename)
+    if (ext.lower() == 'pb'):
+        with gfile.FastGFile(filename, 'rb') as f:
+            graph_def = tf.GraphDef()
+            graph_def.ParseFromString(f.read())
+
+    elif (ext.lower() == 'pbtxt'):
+        with open(filename, 'r') as reader:
+            graph_def = tf.GraphDef()
+            text_format.Parse(reader.read(), graph_def)
+    else:
+        print("# Error: unknown extension - " + ext)
+
+    tf.import_graph_def(graph_def, name=graphNameAfterImporting)
+
+# --------
+def splitDirFilenameExt(path):
+    # in case of '/tmp/.ssh/my.key.dat'
+    # this returns ('/tmp/.ssh', 'my.key', 'dat')
+    directory = os.path.split(path)[0]
+    ext = os.path.splitext(path)[1][1:]   # remove '.', e.g., '.dat' -> 'dat'
+    filename = os.path.splitext(os.path.split(path)[1])[0]
+
+    return (directory, filename, ext)
+
+# --------
+def convertPbtxt2Pb(pbtxtPath):
+    ''' convert pbtxt file to pb file. e.g., /tmp/a.pbtxt --> /tmp/a.pb '''
+    with open(pbtxtPath) as f:
+        txt = f.read()
+
+    gdef = text_format.Parse(txt, tf.GraphDef())
+
+    (directory, filename, ext) = splitDirFilenameExt(pbtxtPath)
+
+    tf.train.write_graph(gdef, directory, filename + '.pb', as_text=False)
+
+    return os.path.join(directory, filename + '.pb')
+
+# --------
+def convertPb2Pbtxt(pbPath):
+    ''' convert pb file to pbtxt file. e.g., /tmp/a.pb --> /tmp/a.pbtxt '''
+
+    from tensorflow.python.platform import gfile
+
+    (directory, filename, ext) = splitDirFilenameExt(pbPath)
+
+    with gfile.FastGFile(pbPath, 'rb') as f:
+        content = f.read()
+
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(content)
+    tf.import_graph_def(graph_def, name='')
+
+    tf.train.write_graph(graph_def, directory, filename + '.pbtxt', as_text=True)
+
+    return os.path.join(directory, filename + '.pbtxt')
+
+# --------
+def savePbAndCkpt(sess, directory):
+    ''' save files related to session's graph into directory.
+        - graph.pb : binary protocol buffer file
+        - graph.pbtxt : text format of protocol buffer file
+        - graph.ckpt.* : checkpoing files contains values of variables
+
+        returns (path of pb file, path of pbtxt file, path of ckpt files)
+    '''
+
+    tf.train.write_graph(sess.graph_def, directory, 'graph.pb', as_text=False)
+    tf.train.write_graph(sess.graph_def, directory, 'graph.pbtxt', as_text=True)
+
+    # save a checkpoint file, which will store the above assignment
+    saver = tf.train.Saver()
+    saver.save(sess, os.path.join(directory, 'graph.ckpt'))
+
+    return (os.path.join(directory, 'graph.pb'),
+            os.path.join(directory, 'graph.pbtxt'),
+            os.path.join(directory, 'graph.ckpt'))
+
+# --------
+def freezeGraph(input_graph_path, checkpoint_path, output_node_name):
+    ''' this function calls freeze_grapy.py of tensorflow and generates '*_frozen.pb' and '*_frozen.pbtxt'.
+
+        - input_graph_path : must be a path to pb file
+        - checkpoint_path  : path of *.ckpt, e.g., '/tmp/inception_v3/graph.ckpt'
+        - output_node_name : name of head(top) operation node
+        '''
+
+    input_saver_def_path = ""
+    input_binary = True
+
+    restore_op_name = "save/restore_all"
+    filename_tensor_name = "save/Const:0"
+    clear_devices = True
+
+    (directory, fn, ext) = splitDirFilenameExt(input_graph_path)
+    output_frozen_graph_path = os.path.join(directory, fn + '_frozen.pb')
+
+    if file_validity_check(input_graph_path, 'pb') == False:
+        print("Error: {} not found or not have pb extension".format(input_graph_path))
+        sys.exit(0)
+
+    import platform
+    if platform.python_version()[0] != '2':
+        print("python version must be 2.x")
+        sys.exit(0)
+
+    freeze_graph.freeze_graph(input_graph_path,
+                              input_saver_def_path,
+                              input_binary,
+                              checkpoint_path,
+                              output_node_name,
+                              restore_op_name,
+                              filename_tensor_name,
+                              output_frozen_graph_path,
+                              clear_devices,
+                              "")
+
+    pbtxtPath = convertPb2Pbtxt(output_frozen_graph_path)
+
+    return (output_frozen_graph_path, pbtxtPath)
+
+# --------
+def generateTensorboardLog(pbFiles, graphNames, directory):
+    ''' Generate logs for tensorboard. after calling this, graph(s) can be viewed inside tensorboard.
+        This function creates a new Session(), so call this outside of 'with Session():'
+
+        parameters:
+        - pbFiles: if multiple graphs needs to be shown, pass the list of pb (or pbtxt) files
+        - directory: parent directory of '/.tensorboard' directory where log files are saved
+
+        how to run tensorboard:
+              $ tensorboard --logdir={return value of this function == directory + '/.tensorboard'}
+    '''
+    assert len(pbFiles) == len(graphNames)
+
+    tf.reset_default_graph()  # without this, graph used previous session is reused : https://stackoverflow.com/questions/42706761/closing-session-in-tensorflow-doesnt-reset-graph
+    with tf.Session() as sess:
+
+        i = 0
+        for pbFile in pbFiles:
+            graphName = graphNames[i]
+            importGraphIntoSession(sess, pbFile, graphName)
+            print("# graph file {} is imported successfully with name = {}".format(pbFile, graphName))
+            i = i + 1
+
+    tbLogPath = os.path.join(directory, ".tensorboard")
+    train_writer = tf.summary.FileWriter(tbLogPath)
+    train_writer.add_graph(sess.graph)
+    train_writer.flush()
+    train_writer.close()
+
+    return tbLogPath
diff --git a/tools/tensorflow_model_freezer/readme.md b/tools/tensorflow_model_freezer/readme.md
new file mode 100644 (file)
index 0000000..9a0c0f9
--- /dev/null
@@ -0,0 +1,20 @@
+## What this tool is about
+
+This tool (python module) does the following:
+
+- Programmer can add Tensorflow code to build a specific graph. --(1)
+- This tool generates the graph into graph def (pb, pbtxt) and checkpoint files. -- (2)
+- Files from (2) is freezed and frozen graph file is generated. -- (3) 
+- Additionally, the visual structure of original graph (2) and frozen graph (3) are viewed using Tensorboard.
+
+## How to use
+- This tool is tested with Python 2.7 with Tensorflow 1.6
+- Programmer can add his/her graph into freeze_programmed_tensor_graph.py
+  - Add code as a subclass of `TensorFlowModelFreezer`
+  - Currently two sample subclasses are written: 
+    - ReluTestModelFreezer: Relu(wx+b) 
+    - SoftmaxTestModelFreezer: Softmax(wx+b)
+- Run `$ (nnfw root) python tools/tensorflow_model_freezer/freeze_programmed_tensor_graph.py -r output_directory` and model files will be generated under output_directory.
+  - `-r` means the output_directory will be removed before generating files. If `-r` is not provided, the directory will be overwritten.
+- To launch Tensorboard, run, e.g., `$ tensorboard --logdir=output_directory/relu/.tensorbrd` 
+