3 import tensorflow as tf
11 Returns an ArgumentParser for generating output_arrays.
13 parser = argparse.ArgumentParser(
14 description=("Command line tool to generated output_arrays of BCQ nodes"))
16 # Input and output path.
21 help="Full filepath of the input file.",
27 help="Full filepath of the output file.",
33 def load_graph(frozen_graph_filename):
35 Load graph from frozen pb file
37 with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
38 graph_def = tf.compat.v1.GraphDef()
39 graph_def.ParseFromString(f.read())
40 with tf.Graph().as_default() as graph:
41 tf.import_graph_def(graph_def, name='')
48 elif dtype == "int64":
50 elif dtype == "float32":
55 raise Exception("Not supported dtype")
58 def print_output_arrays(flags):
59 graph_model = load_graph(flags.input_path)
60 graph_model_def = graph_model.as_graph_def()
61 ops = graph_model.get_operations()
63 output_names = [op.outputs[0].name for op in ops
64 if op.type == "Const" and "bcqinfo_" in op.outputs[0].name]
67 for output_name in output_names:
70 colon_index = output_name.find(":")
72 output_arrays += output_name
74 output_arrays += output_name[:colon_index]
76 f = open(flags.output_path, 'w')
77 f.write(output_arrays)
83 parser = _get_parser()
84 flags = parser.parse_known_args(args=sys.argv[1:])
86 print_output_arrays(flags[0])
89 if __name__ == "__main__":