Simplify file reading and support SavedModel.
authorYao Zhang <yaozhang@google.com>
Thu, 3 May 2018 19:19:01 +0000 (12:19 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 3 May 2018 20:37:14 +0000 (13:37 -0700)
PiperOrigin-RevId: 195291836

tensorflow/python/grappler/cost_analyzer_tool.py

index 0853db2..e6229e1 100644 (file)
@@ -21,11 +21,13 @@ from __future__ import print_function
 import argparse
 import sys
 
+from google.protobuf import message
 from google.protobuf import text_format
 from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op  # pylint: disable=unused-import
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.protobuf import saved_model_pb2
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.grappler import cost_analyzer
@@ -37,33 +39,42 @@ from tensorflow.python.training import saver
 
 def get_metagraph():
   """Constructs and returns a MetaGraphDef from the input file."""
-  if FLAGS.metagraphdef:
-    with gfile.GFile(FLAGS.metagraphdef) as meta_file:
-      metagraph = meta_graph_pb2.MetaGraphDef()
-      if FLAGS.metagraphdef.endswith(".pbtxt"):
-        text_format.Merge(meta_file.read(), metagraph)
-      else:
-        metagraph.ParseFromString(meta_file.read())
-    if FLAGS.fetch is not None:
-      fetch_collection = meta_graph_pb2.CollectionDef()
-      for fetch in FLAGS.fetch.split(","):
-        fetch_collection.node_list.value.append(fetch)
-      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
-  else:
-    with gfile.GFile(FLAGS.graphdef) as graph_file:
-      graph_def = graph_pb2.GraphDef()
-      if FLAGS.graphdef.endswith(".pbtxt"):
-        text_format.Merge(graph_file.read(), graph_def)
-      else:
-        graph_def.ParseFromString(graph_file.read())
-      importer.import_graph_def(graph_def, name="")
-      graph = ops.get_default_graph()
-      for fetch in FLAGS.fetch.split(","):
-        fetch_op = graph.get_operation_by_name(fetch)
-        graph.add_to_collection("train_op", fetch_op)
-      metagraph = saver.export_meta_graph(
-          graph_def=graph.as_graph_def(), graph=graph)
-  return metagraph
+  with gfile.GFile(FLAGS.input) as input_file:
+    input_data = input_file.read()
+    try:
+      saved_model = saved_model_pb2.SavedModel()
+      text_format.Merge(input_data, saved_model)
+      meta_graph = saved_model.meta_graphs[0]
+    except text_format.ParseError:
+      try:
+        saved_model.ParseFromString(input_data)
+        meta_graph = saved_model.meta_graphs[0]
+      except message.DecodeError:
+        try:
+          meta_graph = meta_graph_pb2.MetaGraphDef()
+          text_format.Merge(input_data, meta_graph)
+        except text_format.ParseError:
+          try:
+            meta_graph.ParseFromString(input_data)
+          except message.DecodeError:
+            try:
+              graph_def = graph_pb2.GraphDef()
+              text_format.Merge(input_data, graph_def)
+            except text_format.ParseError:
+              try:
+                graph_def.ParseFromString(input_data)
+              except message.DecodeError:
+                raise ValueError("Invalid input file.")
+            importer.import_graph_def(graph_def, name="")
+            graph = ops.get_default_graph()
+            meta_graph = saver.export_meta_graph(
+                graph_def=graph.as_graph_def(), graph=graph)
+  if FLAGS.fetch is not None:
+    fetch_collection = meta_graph_pb2.CollectionDef()
+    for fetch in FLAGS.fetch.split(","):
+      fetch_collection.node_list.value.append(fetch)
+    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
+  return meta_graph
 
 
 def main(_):
@@ -85,15 +96,11 @@ def main(_):
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument(
-      "--metagraphdef",
+      "--input",
       type=str,
       default=None,
-      help="Input .meta MetaGraphDef file path.")
-  parser.add_argument(
-      "--graphdef",
-      type=str,
-      default=None,
-      help="Input .pb GraphDef file path.")
+      help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in "
+      "either binary or text format.")
   parser.add_argument(
       "--fetch",
       type=str,