Support parsing from text and fused op in contrib.
authorYao Zhang <yaozhang@google.com>
Mon, 5 Feb 2018 19:23:40 +0000 (11:23 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 19:27:52 +0000 (11:27 -0800)
PiperOrigin-RevId: 184558131

tensorflow/python/grappler/cost_analyzer_tool.py

index 61dc4e2..ac251f2 100644 (file)
@@ -22,7 +22,7 @@ import argparse
 import sys
 
 from google.protobuf import text_format
-
+from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op  # pylint: disable=unused-import
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
@@ -43,7 +43,10 @@ def main(_):
   else:
     with gfile.GFile(FLAGS.graphdef) as graph_file:
       graph_def = graph_pb2.GraphDef()
-      graph_def.ParseFromString(graph_file.read())
+      if FLAGS.graphdef.endswith(".pbtxt"):
+        text_format.Merge(graph_file.read(), graph_def)
+      else:
+        graph_def.ParseFromString(graph_file.read())
       importer.import_graph_def(graph_def, name="")
       graph = ops.get_default_graph()
       fetch = graph.get_operation_by_name(FLAGS.fetch)