[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorMerge.py
"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from collections import deque
from copy import deepcopy

import numpy as np

from extensions.ops.tensor_iterator import TensorIterator
from mo.graph.graph import Node, Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
from mo.utils.graph import sub_graph_between_nodes

stop_nodes = ['TensorIteratorInput', 'TensorIteratorOutput', 'TensorIteratorBackEdge', 'TensorIteratorCondition']


def op_type(graph, node_name: str):
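    """Return the 'op' attribute for an op-kind node of the graph, or None for data nodes."""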
    node = Node(graph, node_name)
    if node.has_valid('kind') and node['kind'] == 'op':
        return node['op']
    else:
        return None


def update_inputs(graph, inputs: list, node_name: str):
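    """Append node_name to inputs if it is a TensorIteratorInput op that has not been collected yet."""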
    node = Node(graph, node_name)
    if node.has_valid('kind') and node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
        if node_name not in inputs:
            inputs.append(node_name)


def reverse_dfs(graph: Graph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
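    """Traverse the graph backwards (over input edges) starting from node_name, stopping at stop_nodes;
    TensorIteratorInput stop nodes met on the way are collected into inputs."""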
    d = deque()

    if visited is None:
        visited = set()
    visited.add(node_name)
    d.appendleft(node_name)
    while len(d) != 0:
        cur_node = d.popleft()
        for in_node_name, _ in graph.in_edges(cur_node):
            if in_node_name not in visited:
                if op_type(graph, in_node_name) not in stop_nodes:
                    visited.add(in_node_name)
                    d.append(in_node_name)
                else:
                    update_inputs(graph, inputs, in_node_name)


def dfs(graph: Graph, node_name: str, stop_nodes: list, visited: set = None):
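    """Traverse the graph forward (over output edges) starting from node_name, stopping at stop_nodes."""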
    d = deque()

    if visited is None:
        visited = set()
    visited.add(node_name)
    d.appendleft(node_name)
    while len(d) != 0:
        cur_node = d.popleft()
        for _, out_node_name in graph.out_edges(cur_node):
            if out_node_name not in visited:
                if op_type(graph, out_node_name) not in stop_nodes:
                    visited.add(out_node_name)
                    d.append(out_node_name)


def get_body(graph, inputs, outputs):
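    """Return the body nodes located between the given inputs and outputs (boundary nodes excluded)
    along with the extra TensorIteratorInput nodes detected by sub_graph_between_nodes."""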
    nodes, extra_inputs = sub_graph_between_nodes(
        graph,
        inputs,
        outputs,
        lambda node: node.soft_get('op') == 'TensorIteratorInput'
    )
    nodes = list(set(nodes) - set(inputs) - set(outputs) - set(extra_inputs))
    return nodes, extra_inputs


class TensorIteratorMerge(MiddleReplacementPattern):
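    """Merge the detected TensorIterator parts (condition, inputs/outputs, back edges and body)
    into a single TensorIterator op carrying an internal body graph."""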
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        return []

    def run_before(self):
        return []

    @staticmethod
    def pattern():
        return dict(
            nodes=[
                ('condition', dict(kind='op', op='TensorIteratorCondition')),
            ],
            edges=[],
        )

    @staticmethod
    def replace_pattern(graph, match: dict):
        # Here we find all parts of the TI: condition, inputs/outputs, back edges and body,
        # then create a TensorIterator op and perform all checks needed for the TensorIterator to work
        cond_data = match['condition'].out_node(0)
        time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) > 1 else None
        name = match['condition'].name

        assert match['condition'].in_node(0).has_valid('value')

        back_edges = []
        inputs = []
        outputs = []

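        # Classify consumers of the condition data into back edges, TI inputs and TI outputs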
        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                else:
                    # only TensorIteratorInput/TensorIteratorOutput ops may consume the time data
                    assert False, 'Unexpected consumer of the TensorIteratorCondition time output'

        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)
        graph.remove_nodes_from([condition.id, cond_data.id, tensor_sequence_length.id])
        if time_data is not None:
            graph.remove_nodes_from([time_data.id])

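        # Collect the body: all nodes lying between the TI input and output boundary ops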
        body_nodes, extra_inputs = get_body(graph, inputs, outputs)
        body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

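        # For every boundary op, record the external/internal data nodes together with the slicing attributes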
        external_inputs = [
            {
                'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                'internal_data_id': node.out_node(0),
                'axis': node.axis,
                'start': node.start,
                'end': node.end,
                'stride': node.stride,
                'part_size': node.part_size
            } for node in inputs]

        external_outputs = [
            {
                'external_data_id': node.out_node(0),
                'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                'axis': node.axis,
                'start': node.start,
                'end': node.end,
                'stride': node.stride,
                'part_size': node.part_size
            } for node in outputs]

        back_edges_data = [
            {
                'from_data_id': node.in_node(1),
                'to_data_id': node.out_node(0),
                'init_data_id': node.in_node(0),
            } for node in back_edges
        ]

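        # Build a separate graph for the TI body from the collected nodes and the edges between them,
        # then remove the body and boundary nodes from the main graph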
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        body.add_edges_from(
            [(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])

        graph.remove_nodes_from(
            body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
        internal_id_count = 0
        real_back_edges = []
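        # Rewire each back edge inside the body: its start becomes a body output, and its end is replaced
        # by direct edges from the initial data node to every consumer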
        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1
            # layer id
            if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
                edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

            # port id
            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

            # Look at all consumers of the data node that ends a back edge;
            # for each such consumer there will be a separate back edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):

                real_edge = {}
                real_edge.update(edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

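                # The producer of 'to_data_id' (the back-edge op) is not expected to carry an
                # internal_layer_id here; internal ids are assigned to each consumer instead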
                if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                    assert False
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert 'internal_port_id' not in real_edge['init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key

                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                (
                    real_edge['init_data_id'].id,
                    real_edge['consumer'].id,
                    real_edge['consumer_key'],
                    real_edge['attrs'])
                for real_edge in current_real_back_edges])

            body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

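        # Move every external input boundary into the body: squeeze the iteration axis if the input
        # is partitioned and assign internal layer/port ids to each consumer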
        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert a squeezing Reshape at the input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                                                      dict(shape=np.insert(shape, ext_inp['axis'], 1)))
                dim = shape.copy()
                # Try to make it dynamically reshapable along one of the axes.
                # It is practically useful to reshape along the batch dimension, but here we cannot detect where it is,
                # so we are guessing, based on other transformations, that it is the major dimension.
                dim[0] = -1
                reshape_op = Reshape(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze', dim=dim))
                reshape_op.create_node_with_data([new_input_data], data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
                real_ext_inp = {}
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if 'internal_port_id' not in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
                real_external_inputs.append(real_ext_inp)

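        # Move every external output boundary into the body: unsqueeze the iteration axis if the output
        # is partitioned, mark the data node as a body output and assign internal layer/port ids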
        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert an unsqueezing Reshape at the output port that has partitioning
                dim = ext_out['internal_data_id'].shape.copy()
                # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                dim[0] = -1
                assert not ext_out['internal_data_id'].has_valid('value')
                reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze',
                                                dim=np.insert(dim, ext_out['axis'], 1)))
                ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id']])

            # TODO: add here working with simple outputs

            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)
            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if 'internal_layer_id' not in ext_out['internal_data_id'].in_node():
                ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if 'internal_port_id' not in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

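        # Create the TensorIterator op carrying the body graph, the input/output port maps and the back edges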
        ti_op = TensorIterator(graph, {
            'name': name + '/TensorIterator',
            'body': body,
            'in_ports_count': len(external_inputs),
            'out_ports_count': len(external_outputs),

            'input_port_map': [
                {field: external_input[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                  'end']}
                for external_input in real_external_inputs],

            'output_port_map': [
                {field: external_output[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                  'end']}
                for external_output in external_outputs],
            'back_edges': [
                {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
                for edge in real_back_edges],
        })

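        # Instantiate the TensorIterator node in the main graph, connecting the external input data nodes
        # and reusing the original external output data nodes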
        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs]
        )

        if not isinstance(ti_outs, list):
            ti_outs = [ti_outs]

        for i, out in enumerate(ti_outs):
            out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']