import argparse
import sys
+from google.protobuf import message
from google.protobuf import text_format
from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
def get_metagraph():
"""Constructs and returns a MetaGraphDef from the input file."""
- if FLAGS.metagraphdef:
- with gfile.GFile(FLAGS.metagraphdef) as meta_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- if FLAGS.metagraphdef.endswith(".pbtxt"):
- text_format.Merge(meta_file.read(), metagraph)
- else:
- metagraph.ParseFromString(meta_file.read())
- if FLAGS.fetch is not None:
- fetch_collection = meta_graph_pb2.CollectionDef()
- for fetch in FLAGS.fetch.split(","):
- fetch_collection.node_list.value.append(fetch)
- metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
- else:
- with gfile.GFile(FLAGS.graphdef) as graph_file:
- graph_def = graph_pb2.GraphDef()
- if FLAGS.graphdef.endswith(".pbtxt"):
- text_format.Merge(graph_file.read(), graph_def)
- else:
- graph_def.ParseFromString(graph_file.read())
- importer.import_graph_def(graph_def, name="")
- graph = ops.get_default_graph()
- for fetch in FLAGS.fetch.split(","):
- fetch_op = graph.get_operation_by_name(fetch)
- graph.add_to_collection("train_op", fetch_op)
- metagraph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- return metagraph
+ with gfile.GFile(FLAGS.input) as input_file:
+ input_data = input_file.read()
+ try:
+ saved_model = saved_model_pb2.SavedModel()
+ text_format.Merge(input_data, saved_model)
+ meta_graph = saved_model.meta_graphs[0]
+ except text_format.ParseError:
+ try:
+ saved_model.ParseFromString(input_data)
+ meta_graph = saved_model.meta_graphs[0]
+ except message.DecodeError:
+ try:
+ meta_graph = meta_graph_pb2.MetaGraphDef()
+ text_format.Merge(input_data, meta_graph)
+ except text_format.ParseError:
+ try:
+ meta_graph.ParseFromString(input_data)
+ except message.DecodeError:
+ try:
+ graph_def = graph_pb2.GraphDef()
+ text_format.Merge(input_data, graph_def)
+ except text_format.ParseError:
+ try:
+ graph_def.ParseFromString(input_data)
+ except message.DecodeError:
+ raise ValueError("Invalid input file.")
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ meta_graph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
+ if FLAGS.fetch is not None:
+ fetch_collection = meta_graph_pb2.CollectionDef()
+ for fetch in FLAGS.fetch.split(","):
+ fetch_collection.node_list.value.append(fetch)
+ meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
+ return meta_graph
def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--metagraphdef",
+ "--input",
type=str,
default=None,
- help="Input .meta MetaGraphDef file path.")
- parser.add_argument(
- "--graphdef",
- type=str,
- default=None,
- help="Input .pb GraphDef file path.")
+ help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in "
+ "either binary or text format.")
parser.add_argument(
"--fetch",
type=str,