From e5854637cc3f8099586f18ed144fd6d4f90a6fc7 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Thu, 3 May 2018 12:19:01 -0700 Subject: [PATCH] Simplify file reading and support SavedModel. PiperOrigin-RevId: 195291836 --- tensorflow/python/grappler/cost_analyzer_tool.py | 75 +++++++++++++----------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 0853db2..e6229e1 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -21,11 +21,13 @@ from __future__ import print_function import argparse import sys +from google.protobuf import message from google.protobuf import text_format from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer @@ -37,33 +39,42 @@ from tensorflow.python.training import saver def get_metagraph(): """Constructs and returns a MetaGraphDef from the input file.""" - if FLAGS.metagraphdef: - with gfile.GFile(FLAGS.metagraphdef) as meta_file: - metagraph = meta_graph_pb2.MetaGraphDef() - if FLAGS.metagraphdef.endswith(".pbtxt"): - text_format.Merge(meta_file.read(), metagraph) - else: - metagraph.ParseFromString(meta_file.read()) - if FLAGS.fetch is not None: - fetch_collection = meta_graph_pb2.CollectionDef() - for fetch in FLAGS.fetch.split(","): - fetch_collection.node_list.value.append(fetch) - metagraph.collection_def["train_op"].CopyFrom(fetch_collection) - else: - with gfile.GFile(FLAGS.graphdef) as graph_file: - graph_def = graph_pb2.GraphDef() - if FLAGS.graphdef.endswith(".pbtxt"): - text_format.Merge(graph_file.read(), graph_def) - else: - graph_def.ParseFromString(graph_file.read()) - importer.import_graph_def(graph_def, name="") - graph = ops.get_default_graph() - for fetch in FLAGS.fetch.split(","): - fetch_op = graph.get_operation_by_name(fetch) - graph.add_to_collection("train_op", fetch_op) - metagraph = saver.export_meta_graph( - graph_def=graph.as_graph_def(), graph=graph) - return metagraph + with gfile.GFile(FLAGS.input) as input_file: + input_data = input_file.read() + try: + saved_model = saved_model_pb2.SavedModel() + text_format.Merge(input_data, saved_model) + meta_graph = saved_model.meta_graphs[0] + except text_format.ParseError: + try: + saved_model.ParseFromString(input_data) + meta_graph = saved_model.meta_graphs[0] + except message.DecodeError: + try: + meta_graph = meta_graph_pb2.MetaGraphDef() + text_format.Merge(input_data, meta_graph) + except text_format.ParseError: + try: + meta_graph.ParseFromString(input_data) + except message.DecodeError: + try: + graph_def = graph_pb2.GraphDef() + text_format.Merge(input_data, graph_def) + except text_format.ParseError: + try: + graph_def.ParseFromString(input_data) + except message.DecodeError: + raise ValueError("Invalid input file.") + importer.import_graph_def(graph_def, name="") + graph = ops.get_default_graph() + meta_graph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + if FLAGS.fetch is not None: + fetch_collection = meta_graph_pb2.CollectionDef() + for fetch in FLAGS.fetch.split(","): + fetch_collection.node_list.value.append(fetch) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + return meta_graph def main(_): @@ -85,15 +96,11 @@ def main(_): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--metagraphdef", + "--input", type=str, default=None, - help="Input .meta MetaGraphDef file path.") - parser.add_argument( - "--graphdef", - type=str, - default=None, - help="Input .pb GraphDef file path.") + help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in " + "either binary or text format.") parser.add_argument( "--fetch", type=str, -- 2.7.4