2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import deque
18 from copy import deepcopy
22 from extensions.ops.tensor_iterator import TensorIterator
23 from mo.graph.graph import Node, Graph, add_opoutput
24 from mo.middle.replacement import MiddleReplacementPattern
25 from mo.ops.op import Op
26 from mo.ops.reshape import Reshape
27 from mo.utils.graph import sub_graph_between_nodes
29 stop_nodes = ['TensorIteratorInput', 'TensorIteratorOutput', 'TensorIteratorBackEdge', 'TensorIteratorCondition']
def op_type(graph, node_name: str):
    """Return the 'op' attribute of the named node, or None for non-op nodes.

    Used by the DFS helpers below to compare a node against the module-level
    `stop_nodes` op-name list (data nodes have no 'op' and yield None, which
    is never in `stop_nodes`).
    """
    node = Node(graph, node_name)
    if node.has_valid('kind') and node['kind'] == 'op':
        # Visible source was truncated here: the op kind must be returned so
        # the `op_type(...) not in stop_nodes` checks in reverse_dfs/dfs work.
        return node['op']
    else:
        return None
def update_inputs(graph, inputs: list, node_name: str):
    """Record node_name in `inputs` (mutated in place) when it is a
    TensorIteratorInput op node that has not been collected yet."""
    candidate = Node(graph, node_name)
    is_ti_input = (candidate.has_valid('kind')
                   and candidate['kind'] == 'op'
                   and candidate['op'] == 'TensorIteratorInput')
    if is_ti_input and node_name not in inputs:
        inputs.append(node_name)
def reverse_dfs(graph: Graph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
    """Walk the graph backwards (over incoming edges) starting from node_name.

    Traversal does not descend through nodes whose op type is in `stop_nodes`;
    instead, each such boundary node that is a TensorIteratorInput is collected
    into `inputs` (mutated in place) via update_inputs.

    :param visited: optional set of already-visited node names; created lazily
                    (None default avoids the mutable-default-argument pitfall).
    """
    # Visible source was truncated: the worklist initialization, the visited
    # guard, the `while` head and the `else:` branch are restored here.
    d = deque()
    if visited is None:
        visited = set()
    visited.add(node_name)
    d.appendleft(node_name)
    while len(d) != 0:
        cur_node = d.popleft()
        for in_node_name, _ in graph.in_edges(cur_node):
            if in_node_name not in visited:
                if op_type(graph, in_node_name) not in stop_nodes:
                    visited.add(in_node_name)
                    d.append(in_node_name)
                else:
                    # Boundary node: record it as a TI input if applicable.
                    update_inputs(graph, inputs, in_node_name)
def dfs(graph: Graph, node_name: str, stop_nodes: list, visited: set = None):
    """Walk the graph forwards (over outgoing edges) starting from node_name.

    Marks every reachable node in `visited` (mutated in place), stopping the
    descent at nodes whose op type is in `stop_nodes`.

    :param visited: optional set of already-visited node names; created lazily
                    (None default avoids the mutable-default-argument pitfall).
    """
    # Visible source was truncated: the worklist initialization, the visited
    # guard and the `while` head are restored here (mirrors reverse_dfs).
    d = deque()
    if visited is None:
        visited = set()
    visited.add(node_name)
    d.appendleft(node_name)
    while len(d) != 0:
        cur_node = d.popleft()
        for _, out_node_name in graph.out_edges(cur_node):
            if out_node_name not in visited:
                if op_type(graph, out_node_name) not in stop_nodes:
                    visited.add(out_node_name)
                    d.append(out_node_name)
def get_body(graph, inputs, outputs):
    """Extract the loop-body node set enclosed between TI boundary nodes.

    :param inputs: ids of TensorIteratorInput nodes bounding the body.
    :param outputs: ids of TensorIteratorOutput nodes bounding the body.
    :return: (nodes, extra_inputs) — ids strictly inside the body (boundary
             nodes removed) and additional TensorIteratorInput nodes discovered
             while isolating the sub-graph.
    """
    # Visible source was truncated: sub_graph_between_nodes was called without
    # its positional arguments; restored as (graph, start_nodes, end_nodes).
    nodes, extra_inputs = sub_graph_between_nodes(
        graph,
        inputs,
        outputs,
        lambda node: node.soft_get('op') == 'TensorIteratorInput'
    )
    nodes = list(set(nodes) - set(inputs) - set(outputs) - set(extra_inputs))
    return nodes, extra_inputs
class TensorIteratorMerge(MiddleReplacementPattern):
    """Collapse the de-sugared TensorIterator primitives — condition sub-network,
    TensorIteratorInput/Output port nodes, back edges and the loop body — into a
    single TensorIterator operation with a nested body Graph.

    NOTE(review): this chunk is partially elided; pattern(), enabled, several
    list initializations and list-comprehension heads are not visible here.
    """
    # Only run on graphs that were detected to contain cycles.
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    # NOTE(review): fragment of the (elided) pattern() match dictionary.
    ('condition', dict(kind='op', op='TensorIteratorCondition')),

    def replace_pattern(graph, match: dict):
        # Here we will find all parts of TI: condition, inputs/outputs, back edges,
        # body; then create the TensorIterator Op and make all checks needed for
        # TensorIterator to work.
        cond_data = match['condition'].out_node(0)
        # The second condition output (iteration counter) is optional.
        time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) > 1 else None
        name = match['condition'].name

        # The sequence length feeding the condition must be a known constant.
        assert match['condition'].in_node(0).has_valid('value')

        # Classify the consumers of the condition output into back edges and
        # TI input/output port nodes.
        # NOTE(review): the back_edges/inputs/outputs list initializations are
        # elided from this chunk.
        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

        # The counter output may also feed TI input/output port nodes.
        if time_data is not None:
            for node in time_data.out_nodes():
                if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                    inputs.append(node.id)
                elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                    outputs.append(node.id)
                # something goes wrong here

        # The condition sub-network is absorbed by the TI op: remove it from
        # the main graph.
        condition = match['condition']
        tensor_sequence_length = condition.in_node(0)
        graph.remove_nodes_from([condition.id, cond_data.id, tensor_sequence_length.id])
        if time_data is not None:
            graph.remove_nodes_from([time_data.id])

        # Collect the loop-body nodes enclosed between the TI port nodes.
        body_nodes, extra_inputs = get_body(graph, inputs, outputs)
        body_nodes = list(set(body_nodes) - set([cond_data]))

        inputs += extra_inputs

        assert all([node in graph.nodes() for node in body_nodes])

        # Switch from node ids to Node wrappers for the rest of the pass.
        inputs = [Node(graph, node) for node in inputs]
        outputs = [Node(graph, node) for node in outputs]
        back_edges = [Node(graph, node) for node in back_edges]

        # Per-input descriptor records. NOTE(review): the comprehension head
        # (external_inputs = [{ ... plus 'axis'/'start'/'end' fields) is elided.
        # Port 1 carries the data when the input is partitioned along an axis.
            'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'internal_data_id': node.out_node(0),
            'stride': node.stride,
            'part_size': node.part_size
        } for node in inputs]

        # Per-output descriptor records (mirror of the inputs above).
        # NOTE(review): comprehension head elided.
            'external_data_id': node.out_node(0),
            'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
            'stride': node.stride,
            'part_size': node.part_size
        } for node in outputs]

        # Back-edge records: data produced at iteration end ('from'), consumed
        # at the next iteration start ('to'), seeded by 'init' for iteration 0.
        # NOTE(review): comprehension head (back_edges_data = [{) elided.
            'from_data_id': node.in_node(1),
            'to_data_id': node.out_node(0),
            'init_data_id': node.in_node(0),
        } for node in back_edges

        # Build a separate Graph that will become the TI body; it shares the
        # graph-level attribute dict with the parent graph.
        body = Graph(name='body')
        body.graph = graph.graph
        body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
        # Copy only the edges fully inside the body (multigraph keys preserved).
        # NOTE(review): the body.add_edges_from( call head is elided.
            [(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])

        # Everything that moved into the body, plus the port nodes, leaves the
        # main graph.
        graph.remove_nodes_from(
            body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
        # Single counter used for both internal_layer_id and internal_port_id
        # assignment across the whole body.
        internal_id_count = 0

        for edge in back_edges_data:
            assert edge['from_data_id'].id in body.nodes()
            assert edge['to_data_id'].id in body.nodes()
            assert edge['init_data_id'].id in body.nodes()
            # Rebind stored Nodes to the body graph.
            edge['from_data_id'] = Node(body, edge['from_data_id'].id)
            edge['to_data_id'] = Node(body, edge['to_data_id'].id)
            edge['init_data_id'] = Node(body, edge['init_data_id'].id)
            # The back-edge source must be exposed as a body output.
            add_opoutput(body, edge['from_data_id'].id, 0, False)

            # Assign/reuse ids for the back-edge start; it comes from from_data_id
            assert len(edge['from_data_id'].in_nodes()) == 1

            if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
                edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

            if 'internal_port_id' not in edge['from_data_id'].in_edge():
                edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

            # Look at all consumers for a data that ends a back-edge
            # For each such consumer, there will be a separate back-edge (and input)
            current_real_back_edges = []
            for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
                # NOTE(review): the `real_edge = {}` initialization is elided.
                real_edge.update(edge)  # all real back_edges have the same back-edge start

                consumer = Node(body, consumer)

                # Prefer propagating an already-assigned layer id; otherwise
                # allocate a fresh one on the consumer.
                if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                    real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                        real_edge['to_data_id'].in_node().internal_layer_id
                elif not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                real_edge['to_layer'] = consumer['internal_layer_id']

                assert 'internal_port_id' not in edge_attrs
                assert len(real_edge['init_data_id'].out_edges()) == 1
                assert not 'internal_port_id' in real_edge['init_data_id'].out_edge()
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
                real_edge['to_port'] = edge_attrs['internal_port_id']
                real_edge['consumer'] = consumer
                real_edge['consumer_key'] = key
                real_edge['attrs'] = deepcopy(edge_attrs)
                current_real_back_edges.append(real_edge)

            # connect initial data node with each consumer providing actual edge attributes
            body.add_edges_from([
                    real_edge['init_data_id'].id,
                    real_edge['consumer'].id,
                    real_edge['consumer_key'],
                for real_edge in current_real_back_edges])

            # The edge-end data node and its producer are replaced by the
            # direct init->consumer connections created above.
            body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
            real_back_edges += current_real_back_edges

        real_external_inputs = []

        for ext_inp in external_inputs:
            assert ext_inp['external_data_id'].id not in body.nodes()
            assert ext_inp['internal_data_id'].id in body.nodes()
            ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

            if ext_inp['axis'] is not None:
                # Insert squeezing resize at input port that has partitioning
                shape = ext_inp['internal_data_id'].shape.copy()
                assert not ext_inp['internal_data_id'].has_valid('value')
                # NOTE(review): relies on numpy imported as np at module level
                # (import not visible in this chunk).
                new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                                                      dict(shape=np.insert(shape, ext_inp['axis'], 1)))

                # try to do it dynamically reshapable along one of the axis
                # it is practically useful to reshape along batch dimension, but here we cannot detect where it is
                # so, we are guessing based on other transformations that it is the major dimension
                # NOTE(review): the computation of `dim` is elided from this chunk.
                reshape_op = Reshape(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze', dim=dim))
                reshape_op.create_node_with_data([new_input_data], data_nodes=[ext_inp['internal_data_id']])
                ext_inp['internal_data_id'] = new_input_data

            ext_inp['internal_data_id']['is_input'] = True
            assert len(ext_inp['internal_data_id'].in_nodes()) == 0
            ext_inp['external_port_id'] = internal_id_count
            internal_id_count += 1
            # One "real" external input record per consumer of the input data.
            for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
                # NOTE(review): the `real_ext_inp = {}` initialization is elided.
                real_ext_inp.update(ext_inp)
                consumer = Node(body, consumer)
                if not consumer.has_valid('internal_layer_id'):
                    consumer['internal_layer_id'] = internal_id_count
                    internal_id_count += 1
                if not 'internal_port_id' in edge_attrs:
                    edge_attrs['internal_port_id'] = internal_id_count
                    internal_id_count += 1
                real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
                real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
                real_external_inputs.append(real_ext_inp)

        for ext_out in external_outputs:
            assert ext_out['external_data_id'].id not in body.nodes()
            assert ext_out['internal_data_id'].id in body.nodes()
            ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

            if ext_out['axis'] is not None:
                # Insert unsqueezing resize at output port that has partitioning
                dim = ext_out['internal_data_id'].shape.copy()
                # trying to make it dynamically reshapable (see related comment above for the first Reshape)
                assert not ext_out['internal_data_id'].has_valid('value')
                reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze',
                                                dim=np.insert(dim, ext_out['axis'], 1)))
                ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id']])

            # TODO: add here working with simple outputs

            # Every external output must be exposed as a body output.
            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)
            # assert len(ext_out['internal_data_id'].out_nodes()) == 0
            assert len(ext_out['internal_data_id'].in_nodes()) == 1
            if not 'internal_layer_id' in ext_out['internal_data_id'].in_node():
                ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
                ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
                internal_id_count += 1
            ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
            ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
            ext_out['external_port_id'] = internal_id_count
            internal_id_count += 1

        # Create the merged TensorIterator op in the main graph, carrying the
        # port maps and back-edge map consumed by the IR emitter.
        # NOTE(review): the 'body', 'input_port_map'/'output_port_map' keys and
        # the tails of the field lists are elided from this chunk.
        ti_op = TensorIterator(graph, {
            'name': name + '/TensorIterator',
            'in_ports_count': len(external_inputs),
            'out_ports_count': len(external_outputs),
                {field: external_input[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                for external_input in real_external_inputs],
                {field: external_output[field] for field in
                 ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
                for external_output in external_outputs],
                {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
                for edge in real_back_edges],

        # Wire the TI node to the external data nodes kept in the main graph.
        ti_outs = ti_op.create_node_with_data(
            inputs=[inp['external_data_id'] for inp in external_inputs],
            edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
            data_nodes=[out['external_data_id'] for out in external_outputs]

        # Normalize the single-output case to a list.
        # NOTE(review): the normalization body is elided from this chunk.
        if not isinstance(ti_outs, list):

        # Stamp each output edge with its external port id.
        for i, out in enumerate(ti_outs):
            out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']