import argparse
import itertools
import sys

import numpy as np
import tensorflow as tf
12 Returns an ArgumentParser for preserving BCQ information.
14 parser = argparse.ArgumentParser(
15 description=("Command line tool to preserve BCQ information"))
17 # Input and output path.
22 help="Full filepath of the input file.",
28 help="Full filepath of the output file.",
def load_graph(frozen_graph_filename):
    """
    Load graph from frozen pb file.

    Args:
        frozen_graph_filename: Path of a binary-serialized GraphDef file.

    Returns:
        A new tf.Graph containing the imported GraphDef. Node names are kept
        exactly as stored in the file (no import prefix).
    """
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into a fresh graph rather than the global default one.
    # name='' disables the "import/" prefix, so callers can match nodes by
    # their original names.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    return graph
def preserve_bcq_info(flags):
    """
    Append a unique dummy value to every BCQ information constant.

    Rules:
      * A Const node whose name contains ``/bcqinfo_do_w_x`` (a BOOL constant)
        is replaced by the int32 constant ``[int(first_value), dummy]``.
      * A Const node whose name contains ``/bcqinfo_number_of_clusters``,
        ``/bcqinfo_size_of_clusters`` or ``/bcqinfo_qbits_of_clusters`` gets
        ``dummy`` appended and is re-emitted as an int32 constant.
      * Every other node is copied unchanged.

    The rewritten graph is written as a binary pb file to
    ``flags.output_path``.

    Dummy values are unique: -1, -2, ..., -N, assigned in node order.
    Negative values are used because positive values may cause some
    confusion with real BCQ information values.

    Args:
        flags: Parsed arguments providing ``input_path`` and ``output_path``.
    """
    # Source of unique negative dummy values: -1, -2, -3, ...
    dummy_values = itertools.count(-1, -1)

    # Name patterns of int-valued BCQ constants that receive a dummy value.
    preserved_bcqinfo_list = ("/bcqinfo_number_of_clusters",
                              "/bcqinfo_size_of_clusters",
                              "/bcqinfo_qbits_of_clusters")

    original_graph_model_def = load_graph(flags.input_path).as_graph_def()

    # Pass 1: compute the replacement tensor for each affected Const node.
    substitution_dict = {}
    for node in original_graph_model_def.node:
        if node.op != "Const":
            continue

        # Because bcqinfo_do_w_x is BOOL type, we cannot add dummy value at
        # the end. Therefore we convert it to an INT32 tensor.
        # .flat[0] takes the first element for both 0-d and 1-d tensors.
        if "/bcqinfo_do_w_x" in node.name:
            original_tensor = tf.make_ndarray(node.attr["value"].tensor)
            substitution_dict[node.name] = tf.make_tensor_proto(
                [int(original_tensor.flat[0]), next(dummy_values)], tf.int32)

        if any(name in node.name for name in preserved_bcqinfo_list):
            original_tensor = tf.make_ndarray(node.attr["value"].tensor)
            substitution_dict[node.name] = tf.make_tensor_proto(
                np.append(original_tensor, next(dummy_values)), tf.int32)

    # Pass 2: rebuild the graph, swapping in the substituted constants.
    new_graph = tf.compat.v1.GraphDef()
    for node in original_graph_model_def.node:
        new_node = new_graph.node.add()
        if node.name in substitution_dict:
            new_node.op = "Const"
            new_node.name = node.name
            # The dtype attr must match the int32 tensor written below.
            # Setting it directly avoids borrowing it from another node, which
            # failed when no such node existed.
            new_node.attr["dtype"].type = tf.int32.as_datatype_enum
            new_node.attr["value"].tensor.CopyFrom(substitution_dict[node.name])
        else:
            new_node.CopyFrom(node)

    tf.io.write_graph(new_graph, '.', flags.output_path, as_text=False)
def main():
    """Parse command-line arguments and write the BCQ-preserved graph."""
    # Parse argument.
    parser = _get_parser()
    # parse_known_args ignores arguments the parser does not define instead of
    # rejecting them; the leftovers are unused here.
    flags, _unknown_args = parser.parse_known_args(args=sys.argv[1:])

    # Generate a new pb file in which BCQ information is preserved.
    preserve_bcq_info(flags)
115 if __name__ == "__main__":