From 13fb1cf9942839a13683f946d2058697163831e8 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Mon, 5 Feb 2018 11:23:40 -0800 Subject: [PATCH] Support parsing from text and fused op in contrib. PiperOrigin-RevId: 184558131 --- tensorflow/python/grappler/cost_analyzer_tool.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 61dc4e2..ac251f2 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -22,7 +22,7 @@ import argparse import sys from google.protobuf import text_format - +from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -43,7 +43,10 @@ def main(_): else: with gfile.GFile(FLAGS.graphdef) as graph_file: graph_def = graph_pb2.GraphDef() - graph_def.ParseFromString(graph_file.read()) + if FLAGS.graphdef.endswith(".pbtxt"): + text_format.Merge(graph_file.read(), graph_def) + else: + graph_def.ParseFromString(graph_file.read()) importer.import_graph_def(graph_def, name="") graph = ops.get_default_graph() fetch = graph.get_operation_by_name(FLAGS.fetch) -- 2.7.4