Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / custom_subgraph_call.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 from re import compile, match, findall
19
20 import networkx as nx
21
22 from mo.front.extractor import update_ie_fields
23 from mo.front.tf.partial_infer.tf import tf_subgraph_infer
24 from mo.graph.graph import Node, merge_edge_props, Graph
25 from mo.utils.graph import nodes_matching_name_pattern, is_connected_component
26
27
28 def replace_subgraph_calls(graph: Graph, patterns_string: str):
29     """
30     The function replaces sub-graphs defined by the node names with single nodes that are executed using the TensorFlow.
31     The patterns applied independently, so N patterns produce N TensorFlow call nodes.
32     :param graph: networkX graph to operate on.
33     :param patterns_string: comma separated list of node names patterns.
34     """
35     cycle_exist = False
36     patterns = patterns_string.split(',')
37     for pattern in patterns:
38         log.info("Merging nodes using pattern '{}'".format(pattern))
39         matched_nodes = nodes_matching_name_pattern(graph, pattern)
40         if len(matched_nodes) != 0:
41             merge_nodes(graph, matched_nodes)
42             try:
43                 # the function 'find_cycle' raises exception if the cycle is not found
44                 nx.find_cycle(graph)
45                 cycle_exist = True
46             except nx.exception.NetworkXNoCycle:
47                 cycle_exist = False
48             if cycle_exist:
49                 log.warning("Graph contains a cycle after merging nodes using pattern '{}'".format(pattern))
50     if cycle_exist:
51         graph.dump_graph_for_graphviz()
52         log.error('graph contains cycle after applying all merge node patterns')
53         
54
55 def offload_operations_to_tf(graph: Graph, op_names_patterns: str):
56     """
57     The function accepts the list of strings with operation names patterns. The patterns applied independently and nodes
58     matching specific pattern are executed using the TF runtime.
59     :param graph: networkX graph to operate on.
60     :param op_names_patterns: string with regular expressions specifying operation names patterns.
61     """
62     patterns = op_names_patterns.split(',')
63     for pattern in patterns:
64         log.info("Running nodes with operation using pattern '{}'".format(pattern))
65         compiled_pattern = compile(pattern)
66         for node_name, attrs in list(graph.nodes(data=True)):
67             if 'pb' in graph.node[node_name]:
68                 op = graph.node[node_name]['pb'].op
69                 if match(compiled_pattern, op):
70                     log.debug("Node '{}' operation matches pattern '{}'".format(node_name, pattern))
71                     merge_nodes(graph, [node_name])
72
73
74 def internal_output_name_for_node(node_name: str, output_port: int):
75     return node_name + ":" + str(output_port)
76
77
78 def add_node_pb_if_not_yet_added(node: Node, mega_node: Node):
79     if node.has_valid('pb') and node.pb.name not in mega_node.pbs.keys():
80         mega_node.pbs[node.pb.name] = node.pb
81
82
83 def find_input_port(node: Node, input_desc: list, search_node_name: str, search_node_port: int):
84     if input_desc is None:
85         return len(node.in_nodes())
86
87     for in_port, tensor_desc in enumerate(input_desc):
88         for node_pattern, node_port in tensor_desc:
89             if findall(node_pattern, search_node_name) and node_port == search_node_port:
90                 return in_port
91     raise Exception('Did not find input port of the node "{}" with port "{}"'.format(search_node_name,
92                                                                                      search_node_port))
93
94
95 def find_output_port(node: Node, output_desc: list, search_node_name: str, search_node_port: int):
96     if output_desc is None:
97         return len(node.out_nodes())
98
99     for out_port, (node_pattern, node_port) in enumerate(output_desc):
100         if findall(node_pattern, search_node_name) and node_port == search_node_port:
101             return out_port
102     raise Exception('Did not find output port of the node "{}" with port "{}"'.format(search_node_name,
103                                                                                       search_node_port))
104
105
106 def merge_nodes(graph: Graph, nodes_to_merge_names: list, inputs_desc: list = None,
107                 outputs_desc: list = None):
108     """
109     Merges nodes specified in the set 'nodes_to_merge_names' into one mega-node, creating new edges between mega-node
110     and inputs/outputs nodes of the mega-node. The added edges contain name of input/output nodes which will be used for
111     generation of placeholders and will be saved to the IR xml so IE plug-in know how to map input/output data for the
112     layer. Also the function adds protobufs of the nodes of the sub-graph and 'Const' ops consumed by nodes in the
113     sub-graph to the node's attribute 'pbs'.
114     :param graph: the graph object to operate on.
115     :param nodes_to_merge_names: list of nodes names that should be merged into a single node.
116     :param inputs_desc: optional list describing input nodes order.
117     :param outputs_desc: optional list describing output nodes order.
118     """
119     if not is_connected_component(graph, nodes_to_merge_names):
120         log.warning("The following nodes do not form connected sub-graph: {}".format(nodes_to_merge_names))
121         graph.dump_graph_for_graphviz(nodes_to_dump=nodes_to_merge_names)
122
123     new_node_name = graph.unique_id("TFSubgraphCall_")
124     log.info("Create new node with name '{}' for nodes '{}'".format(new_node_name, ', '.join(nodes_to_merge_names)))
125     graph.add_node(new_node_name)
126     new_node_attrs = graph.node[new_node_name]
127
128     new_node_attrs['name'] = new_node_name
129     set_tf_custom_call_node_attrs(new_node_attrs)
130     new_node = Node(graph, new_node_name)
131
132     added_input_tensors_names = set()  # set of tensors that are were added as input to the sub-graph
133     added_new_node_output_tensors = dict()  # key - tensor name, value - out port
134
135     for node_name in nodes_to_merge_names:
136         node = Node(graph, node_name)
137         add_node_pb_if_not_yet_added(node, new_node)
138         # TODO: any improvements?
139         for in_node_name, edge_attrs in Node(graph, node_name).get_inputs():
140             in_node = Node(graph, in_node_name)
141
142             # internal edges between nodes of the sub-graph
143             if in_node_name in nodes_to_merge_names:
144                 add_node_pb_if_not_yet_added(in_node, new_node)
145                 continue
146
147             # edge outside of sub-graph into sub-graph
148             if in_node_name not in nodes_to_merge_names:
149                 # we cannot use the 'in_node_name' as a protobuf operation name here
150                 # because the 'in_node_name' could be a sub-graph matched before.
151                 input_tensor_name = node.pb.input[edge_attrs['in']]
152                 if input_tensor_name not in added_input_tensors_names:
153                     graph.add_edge(in_node_name, new_node_name,
154                                    **merge_edge_props(
155                                        {'in': find_input_port(new_node, inputs_desc, node_name, edge_attrs['in']),
156                                         'out': edge_attrs['out'],
157                                         'internal_input_node_name': input_tensor_name,
158                                         'original_dst_node_name': node_name,
159                                         'original_dst_port': edge_attrs['in'],
160                                         'in_attrs': ['in', 'internal_input_node_name', 'original_dst_node_name',
161                                                      'original_dst_port', 'placeholder_name'],
162                                         'out_attrs': ['out']},
163                                        edge_attrs)
164                                    )
165                     log.debug("Creating edge from outside of sub-graph to inside sub-graph: {} -> {}".format(
166                         in_node_name, new_node_name))
167                     added_input_tensors_names.add(input_tensor_name)
168
169         # edge from inside sub-graph to outside sub-graph
170         for out_node_name, edge_attrs in Node(graph, node_name).get_outputs():
171             if out_node_name not in nodes_to_merge_names:
172                 log.debug("Creating edge from inside of sub-graph to outside sub-graph: {} -> {}".format(
173                     new_node_name, out_node_name))
174                 out_name = internal_output_name_for_node(node_name, edge_attrs['out'])
175                 if out_name not in added_new_node_output_tensors.keys():
176                     added_new_node_output_tensors[out_name] = find_output_port(new_node, outputs_desc, node_name,
177                                                                                edge_attrs['out'])
178                 graph.add_edge(new_node_name, out_node_name,
179                                **merge_edge_props(
180                                    {'in': edge_attrs['in'],
181                                     'out': added_new_node_output_tensors[out_name],
182                                     'internal_output_node_name': out_name,
183                                     'in_attrs': ['in', 'internal_input_node_name'],
184                                     'out_attrs': ['out', 'internal_output_node_name']},
185                                    edge_attrs)
186                                )
187         new_node['output_tensors_names'] = [val for val in
188                                             {v: k for k, v in added_new_node_output_tensors.items()}.values()]
189
190     # add nodes using the same order as in initial GraphDef so we can dump them to IR in "correct" order
191     new_node['nodes_order'] = [node for node in graph.graph['initial_nodes_order'] if node in new_node['pbs'].keys()]
192
193     for n in nodes_to_merge_names:
194         if graph.has_node(n):  # check if not deleted by another (similar) pattern
195             graph.remove_node(n)
196     return Node(graph, new_node_name)
197
198
199 def set_tf_custom_call_node_attrs(node_attrs: dict):
200     update_ie_fields(node_attrs)
201     node_attrs['input_nodes_names'] = list()
202     node_attrs['output_tensors_names'] = list()
203     node_attrs['real_input_dims'] = list()
204     node_attrs['pbs'] = dict()
205     node_attrs['type'] = 'TFCustomSubgraphCall'
206     node_attrs['op'] = 'TFCustomSubgraphCall'
207     node_attrs['precision'] = 'FP32'  # TODO use real precision derived from the model
208     node_attrs['infer'] = tf_subgraph_infer
209     node_attrs['kind'] = 'op'
210
211
212 def tf_find_constant_inputs(node: Node):
213     """
214     The function finds constant inputs of the node and nodes with Identity operation.
215     :param node: node to add constants inputs.
216     :return: set of added nodes (Node).
217     """
218     added_nodes = set()
219     for in_node in node.in_nodes().values():
220         if in_node.has_valid('pb'):
221             if in_node['pb'].op == 'Const':
222                 added_nodes.add(in_node)
223             if in_node['pb'].op == 'Identity':
224                 added_nodes.update(tf_find_constant_inputs(in_node))
225     return added_nodes