"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
-
from collections import deque
from copy import deepcopy
-import networkx as nx
import numpy as np
-from mo.graph.graph import Node
-from mo.utils.graph import sub_graph_between_nodes
-from mo.middle.replacement import MiddleReplacementPattern
from extensions.ops.tensor_iterator import TensorIterator
+from mo.graph.graph import Node, Graph, add_opoutput
+from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
+from mo.utils.graph import sub_graph_between_nodes
stop_nodes = ['TensorIteratorInput', 'TensorIteratorOutput', 'TensorIteratorBackEdge', 'TensorIteratorCondition']
+
def op_type(graph, node_name: str):
node = Node(graph, node_name)
if node.has_valid('kind') and node['kind'] == 'op':
inputs.append(node_name)
-def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
+def reverse_dfs(graph: Graph, node_name: str, stop_nodes: list, inputs: list, visited: set = None):
d = deque()
if visited is None:
else:
update_inputs(graph, inputs, in_node_name)
-def dfs(graph: nx.MultiDiGraph, node_name: str, stop_nodes: list, visited: set = None):
+
+def dfs(graph: Graph, node_name: str, stop_nodes: list, visited: set = None):
d = deque()
visited.add(node_name)
visited.add(out_node_name)
d.append(out_node_name)
+
def get_body(graph, inputs, outputs):
nodes, extra_inputs = sub_graph_between_nodes(
graph,
inputs,
outputs,
- lambda node: node.soft_get('op') == 'TensorIteratorInput'
+ lambda node: node.soft_get('op') == 'TensorIteratorInput'
)
nodes = list(set(nodes) - set(inputs) - set(outputs) - set(extra_inputs))
return nodes, extra_inputs
class TensorIteratorMerge(MiddleReplacementPattern):
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ return []
+
+ def run_before(self):
+ return []
+
@staticmethod
def pattern():
return dict(
inputs = [Node(graph, node) for node in inputs]
outputs = [Node(graph, node) for node in outputs]
back_edges = [Node(graph, node) for node in back_edges]
-
+
external_inputs = [
{
'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
'part_size': node.part_size
} for node in inputs]
-
external_outputs = [
{
'external_data_id': node.out_node(0),
'part_size': node.part_size
} for node in outputs]
-
back_edges_data = [
{
'from_data_id': node.in_node(1),
} for node in back_edges
]
- body = nx.MultiDiGraph(name='body')
- body.graph['layout'] = graph.graph['layout']
+ body = Graph(name='body')
+ body.graph = graph.graph
body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
- body.add_edges_from([(u,v,k,d)for u,v,k,d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])
+ body.add_edges_from(
+ [(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True) if u in body_nodes and v in body_nodes])
- graph.remove_nodes_from(body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
+ graph.remove_nodes_from(
+ body_nodes + [match['condition'].id] + [inp.id for inp in inputs] + [out.id for out in outputs])
internal_id_count = 0
real_back_edges = []
for edge in back_edges_data:
edge['from_data_id'] = Node(body, edge['from_data_id'].id)
edge['to_data_id'] = Node(body, edge['to_data_id'].id)
edge['init_data_id'] = Node(body, edge['init_data_id'].id)
- edge['from_data_id']['is_output'] = True
+ add_opoutput(body, edge['from_data_id'].id, 0, False)
# Assign/reuse ids for the back-edge start; it comes from from_data_id
assert len(edge['from_data_id'].in_nodes()) == 1
for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
real_edge = {}
- real_edge.update(edge) # all real back_edges have the same back-edge start
+ real_edge.update(edge) # all real back_edges have the same back-edge start
consumer = Node(body, consumer)
if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
assert False
- real_edge['to_data_id'].out_node()['internal_layer_id'] = real_edge['to_data_id'].in_node().internal_layer_id
+ real_edge['to_data_id'].out_node()['internal_layer_id'] = \
+ real_edge['to_data_id'].in_node().internal_layer_id
elif not consumer.has_valid('internal_layer_id'):
consumer['internal_layer_id'] = internal_id_count
internal_id_count += 1
real_edge['consumer'].id,
real_edge['consumer_key'],
real_edge['attrs'])
- for real_edge in current_real_back_edges])
+ for real_edge in current_real_back_edges])
body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
real_back_edges += current_real_back_edges
# Insert squeezing resize at input port that has partitioning
shape = ext_inp['internal_data_id'].shape.copy()
assert not ext_inp['internal_data_id'].has_valid('value')
- new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput', dict(shape=np.insert(shape, ext_inp['axis'], 1)))
+ new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
+ dict(shape=np.insert(shape, ext_inp['axis'], 1)))
dim = shape.copy()
# try to do it dynamically reshapable along one of the axis
# it is practically useful to reshape along batch dimension, but here we cannot detect where it is
# trying to make it dynamically reshapable (see related comment above for the first Reshape)
dim[0] = -1
assert not ext_out['internal_data_id'].has_valid('value')
- reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze', dim=np.insert(dim, ext_out['axis'], 1)))
+ reshape_op = Reshape(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze',
+ dim=np.insert(dim, ext_out['axis'], 1)))
ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id']])
# TODO: add here working with simple outputs
- ext_out['internal_data_id']['is_output'] = True
- #assert len(ext_out['internal_data_id'].out_nodes()) == 0
+ add_opoutput(body, ext_out['internal_data_id'].id, 0, False)
+ # assert len(ext_out['internal_data_id'].out_nodes()) == 0
assert len(ext_out['internal_data_id'].in_nodes()) == 1
if not 'internal_layer_id' in ext_out['internal_data_id'].in_node():
ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
ti_op = TensorIterator(graph, {
'name': name + '/TensorIterator',
'body': body,
+ 'in_ports_count': len(external_inputs),
+ 'out_ports_count': len(external_outputs),
'input_port_map': [
- {field: external_input[field] for field in [ 'external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start', 'end']}
+ {field: external_input[field] for field in
+ ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
+ 'end']}
for external_input in real_external_inputs],
'output_port_map': [
- {field: external_output[field] for field in [ 'external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start', 'end']}
+ {field: external_output[field] for field in
+ ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis', 'stride', 'part_size', 'start',
+ 'end']}
for external_output in external_outputs],
'back_edges': [
- {field: edge[field] for field in [ 'from_layer', 'from_port', 'to_layer', 'to_port']}
+ {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
for edge in real_back_edges],
})
for i, out in enumerate(ti_outs):
out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']
-
-
-
- # Create TI operation