2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from re import compile, match, findall
22 from mo.front.extractor import update_ie_fields
23 from mo.front.tf.partial_infer.tf import tf_subgraph_infer
24 from mo.graph.graph import Node, merge_edge_props, Graph
25 from mo.utils.graph import nodes_matching_name_pattern, is_connected_component
def replace_subgraph_calls(graph: Graph, patterns_string: str):
    """
    Replace each sub-graph selected by a node-name pattern with a single node executed by TensorFlow.

    Every comma-separated pattern is handled on its own, so N patterns yield up to N TensorFlow call nodes
    (a pattern that matches no node is skipped). After each merge the graph is checked for cycles: a warning is
    logged for the pattern whose merge left a cycle, and if the graph is cyclic after the last performed merge the
    graph is dumped for graphviz and an error is logged.
    :param graph: networkX graph to operate on.
    :param patterns_string: comma separated list of node names patterns.
    """
    has_cycle = False
    for name_pattern in patterns_string.split(','):
        log.info("Merging nodes using pattern '{}'".format(name_pattern))
        nodes_to_merge = nodes_matching_name_pattern(graph, name_pattern)
        if len(nodes_to_merge) == 0:
            continue
        merge_nodes(graph, nodes_to_merge)
        try:
            # 'find_cycle' signals an acyclic graph by raising NetworkXNoCycle
            nx.find_cycle(graph)
        except nx.exception.NetworkXNoCycle:
            has_cycle = False
        else:
            has_cycle = True
            log.warning("Graph contains a cycle after merging nodes using pattern '{}'".format(name_pattern))

    if has_cycle:
        graph.dump_graph_for_graphviz()
        log.error('graph contains cycle after applying all merge node patterns')
def offload_operations_to_tf(graph: Graph, op_names_patterns: str):
    """
    Offload nodes whose TensorFlow operation type matches one of the given patterns to the TF runtime.

    The patterns are applied independently. For every pattern, each node that carries a 'pb' attribute and whose
    ``pb.op`` matches the pattern is wrapped into its own TensorFlow call node via ``merge_nodes``. Matching uses
    ``re.match`` semantics, i.e. the pattern is anchored at the start of the operation name.
    Entries are stripped of surrounding whitespace, so "Const, Identity" works as expected. Empty entries (e.g.
    produced by a trailing comma) are skipped: an empty regular expression matches every operation and would
    silently offload the whole graph.
    :param graph: networkX graph to operate on.
    :param op_names_patterns: string with comma separated regular expressions specifying operation names patterns.
    """
    for pattern in op_names_patterns.split(','):
        pattern = pattern.strip()
        if not pattern:
            log.warning("Skipping empty operation name pattern in '{}'".format(op_names_patterns))
            continue
        log.info("Running nodes with operation using pattern '{}'".format(pattern))
        compiled_pattern = compile(pattern)
        # Iterate over a snapshot: merge_nodes() adds a new node and removes the merged one while we loop.
        for node_name, attrs in list(graph.nodes(data=True)):
            # 'attrs' is the node's attribute dict, so no second lookup through graph.node is needed
            if 'pb' not in attrs:
                continue
            op = attrs['pb'].op
            if compiled_pattern.match(op):
                log.debug("Node '{}' operation matches pattern '{}'".format(node_name, pattern))
                merge_nodes(graph, [node_name])
def internal_output_name_for_node(node_name: str, output_port: int):
    """
    Return the internal name "<node_name>:<output_port>" identifying one output of a node.

    :param node_name: name of the producing node (must be a str).
    :param output_port: index of the node's output port.
    :return: the two values joined by a colon.
    """
    return ':'.join((node_name, str(output_port)))
def add_node_pb_if_not_yet_added(node: Node, mega_node: Node):
    """
    Register the protobuf of `node` in `mega_node.pbs`, keyed by the protobuf's name.

    Nothing is done when `node` has no valid 'pb' attribute, or when a protobuf with the same name is already
    registered (the first one stored under a name is kept).
    :param node: node whose 'pb' attribute should be recorded.
    :param mega_node: node whose 'pbs' dict collects the protobufs.
    """
    if not node.has_valid('pb'):
        return
    node_pb = node.pb
    registered_pbs = mega_node.pbs
    if node_pb.name not in registered_pbs:
        registered_pbs[node_pb.name] = node_pb
def find_input_port(node: Node, input_desc: list, search_node_name: str, search_node_port: int):
    """
    Return the input port of `node` that should receive input `search_node_port` of node `search_node_name`.

    When `input_desc` is None the result is `len(node.in_nodes())`, presumably the number of inputs already
    attached to `node`. Otherwise `input_desc[i]` is a list of (node name regex, port) alternatives for port i.
    The first i having an alternative whose regex occurs anywhere in `search_node_name` (re.findall) and whose
    port equals `search_node_port` is returned.
    :raises Exception: if `input_desc` is given but no entry matches.
    """
    if input_desc is None:
        return len(node.in_nodes())

    for port_index, alternatives in enumerate(input_desc):
        if any(findall(pattern, search_node_name) and port == search_node_port
               for pattern, port in alternatives):
            return port_index
    raise Exception('Did not find input port of the node "{}" with port "{}"'.format(search_node_name,
                                                                                     search_node_port))
def find_output_port(node: Node, output_desc: list, search_node_name: str, search_node_port: int):
    """
    Return the output port of `node` that should expose output `search_node_port` of node `search_node_name`.

    When `output_desc` is None the result is `len(node.out_nodes())`, presumably the number of outputs already
    attached to `node`. Otherwise `output_desc[i]` is a single (node name regex, port) pair for port i. The first
    i whose regex occurs anywhere in `search_node_name` (re.findall) and whose port equals `search_node_port` is
    returned.
    :raises Exception: if `output_desc` is given but no entry matches.
    """
    if output_desc is None:
        return len(node.out_nodes())

    # lazily scan so entries after the first match are never evaluated
    candidates = (index for index, (pattern, port) in enumerate(output_desc)
                  if findall(pattern, search_node_name) and port == search_node_port)
    found_port = next(candidates, None)
    if found_port is None:
        raise Exception('Did not find output port of the node "{}" with port "{}"'.format(search_node_name,
                                                                                          search_node_port))
    return found_port
def merge_nodes(graph: Graph, nodes_to_merge_names: list, inputs_desc: list = None,
                outputs_desc: list = None):
    """
    Merges the nodes listed in 'nodes_to_merge_names' into one mega-node of type 'TFCustomSubgraphCall'.

    Edges between the merged nodes and the rest of the graph are re-created on the mega-node. The added edges carry
    the names of the original input/output tensors; these are used to generate placeholders and are saved to the IR
    xml so the IE plug-in knows how to map input/output data for the layer. The protobufs ('pb' attribute) of the
    merged nodes are collected into the mega-node attribute 'pbs'. The merged nodes are then removed from the graph.
    NOTE(review): only the 'pb' of nodes listed in 'nodes_to_merge_names' is collected here; constants outside the
    list reach the mega-node as inputs, not as 'pbs' entries — confirm whether Const ops are expected in 'pbs'.
    :param graph: the graph object to operate on.
    :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
    :param inputs_desc: optional list describing input nodes order; element i is a list of (node name regex, port)
    pairs mapped onto input port i of the mega-node (see find_input_port).
    :param outputs_desc: optional list describing output nodes order; element i is a (node name regex, port) pair
    mapped onto output port i of the mega-node (see find_output_port).
    :return: Node object for the created mega-node.
    """
    # only a diagnostic: the merge still proceeds when the nodes are not connected
    if not is_connected_component(graph, nodes_to_merge_names):
        log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))
        graph.dump_graph_for_graphviz(nodes_to_dump=nodes_to_merge_names)

    new_node_name = graph.unique_id("TFSubgraphCall_")
    log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name, ', '.join(nodes_to_merge_names)))
    graph.add_node(new_node_name)
    new_node_attrs = graph.node[new_node_name]

    new_node_attrs['name'] = new_node_name
    set_tf_custom_call_node_attrs(new_node_attrs)
    new_node = Node(graph, new_node_name)

    added_input_tensors_names = set()  # set of tensors that were already added as input to the sub-graph
    added_new_node_output_tensors = dict()  # key - tensor name, value - out port

    for node_name in nodes_to_merge_names:
        node = Node(graph, node_name)
        add_node_pb_if_not_yet_added(node, new_node)
        # TODO: any improvements?
        for in_node_name, edge_attrs in Node(graph, node_name).get_inputs():
            in_node = Node(graph, in_node_name)

            # internal edges between nodes of the sub-graph
            if in_node_name in nodes_to_merge_names:
                add_node_pb_if_not_yet_added(in_node, new_node)
                continue

            # edge outside of sub-graph into sub-graph (always true here because of the 'continue' above)
            if in_node_name not in nodes_to_merge_names:
                # we cannot use the 'in_node_name' as a protobuf operation name here
                # because the 'in_node_name' could be a sub-graph matched before.
                input_tensor_name = node.pb.input[edge_attrs['in']]
                # a tensor consumed by several merged nodes becomes a single mega-node input
                if input_tensor_name not in added_input_tensors_names:
                    # the input port is resolved now, so without inputs_desc it depends on how many inputs
                    # the mega-node already has
                    graph.add_edge(in_node_name, new_node_name,
                                   **merge_edge_props(
                                       {'in': find_input_port(new_node, inputs_desc, node_name, edge_attrs['in']),
                                        'out': edge_attrs['out'],
                                        'internal_input_node_name': input_tensor_name,
                                        'original_dst_node_name': node_name,
                                        'original_dst_port': edge_attrs['in'],
                                        'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
                                                     'original_dst_port', 'placeholder_name'],
                                        'out_attrs': ['out']},
                                       edge_attrs)
                                   )
                    log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
                        in_node_name, new_node_name))
                    added_input_tensors_names.add(input_tensor_name)

        # edge from inside sub-graph to outside sub-graph
        for out_node_name, edge_attrs in Node(graph, node_name).get_outputs():
            if out_node_name not in nodes_to_merge_names:
                log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
                    new_node_name, out_node_name))
                out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
                # every distinct internal tensor gets one output port; further consumers reuse it
                if out_name not in added_new_node_output_tensors.keys():
                    added_new_node_output_tensors[out_name] = find_output_port(new_node, outputs_desc, node_name,
                                                                               edge_attrs['out'])
                graph.add_edge(new_node_name, out_node_name,
                               **merge_edge_props(
                                   {'in': edge_attrs['in'],
                                    'out': added_new_node_output_tensors[out_name],
                                    'internal_output_node_name': out_name,
                                    'in_attrs': ['in', 'internal_input_node_name'],
                                    'out_attrs': ['out', 'internal_output_node_name']},
                                   edge_attrs)
                               )

    # Invert to {port: tensor name} and keep the names; the order follows first insertion of each port, and a later
    # name assigned to an already-used port replaces the earlier one.
    # NOTE(review): with outputs_desc the ports may be assigned out of discovery order, so this list is then not
    # sorted by port number — confirm this is what tf_subgraph_infer expects.
    new_node['output_tensors_names'] = [val for val in
                                        {v: k for k, v in added_new_node_output_tensors.items()}.values()]

    # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
    new_node['nodes_order'] = [node for node in graph.graph['initial_nodes_order'] if node in new_node['pbs'].keys()]

    for n in nodes_to_merge_names:
        if graph.has_node(n):  # check if not deleted by another (similar) pattern
            graph.remove_node(n)

    return Node(graph, new_node_name)
def set_tf_custom_call_node_attrs(node_attrs: dict):
    """
    Fill the attribute dict of a freshly created TensorFlow sub-graph call node in place.

    update_ie_fields is applied first. Then the TFCustomSubgraphCall-specific fields are set, overwriting any
    existing values:
    - empty 'input_nodes_names', 'output_tensors_names', 'real_input_dims' and 'pbs' containers (new per call);
    - 'type'/'op' = 'TFCustomSubgraphCall';
    - a fixed 'FP32' precision;
    - tf_subgraph_infer as the shape inference function;
    - 'kind' = 'op'.
    :param node_attrs: attribute dictionary of the node to initialize.
    """
    update_ie_fields(node_attrs)
    node_attrs.update({
        'input_nodes_names': [],
        'output_tensors_names': [],
        'real_input_dims': [],
        'pbs': {},
        'type': 'TFCustomSubgraphCall',
        'op': 'TFCustomSubgraphCall',
        'precision': 'FP32',  # TODO use real precision derived from the model
        'infer': tf_subgraph_infer,
        'kind': 'op',
    })
212 def tf_find_constant_inputs(node: Node):
214 The function finds constant inputs of the node and nodes with Identity operation.
215 :param node: node to add constants inputs.
216 :return: set of added nodes (Node).
219 for in_node in node.in_nodes().values():
220 if in_node.has_valid('pb'):
221 if in_node['pb'].op == 'Const':
222 added_nodes.add(in_node)
223 if in_node['pb'].op == 'Identity':
224 added_nodes.update(tf_find_constant_inputs(in_node))