"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
import logging as log
-import networkx as nx
-
from extensions.ops.TensorIterator_ops import TensorIteratorOutput
+from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
--------> Identity -> TensorArrayWrite -> NextIteration
"""
enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ from extensions.middle.TensorIteratorInput import SmartInputMatcher
+ return [SmartInputMatcher]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ return [TensorIteratorMerge]
@staticmethod
def pattern():
)
@staticmethod
- def replace_pattern(graph: nx.MultiDiGraph, match: dict):
+ def replace_pattern(graph: Graph, match: dict):
log.debug('================== SmartOutputFind ===============')
assert match['WriteEnter_data'].value is not None
if node not in safe_nodes:
nodes_for_remove.append(match[node].id)
graph.remove_nodes_from(nodes_for_remove)
+
+
+class SimpleOutputMatcher(MiddleReplacementPattern):
+ """
+ This pattern match partitioned outputs for TensorIterator in dynamic_rnn loops in TF.
+ The structure of pattern without Data nodes between ops. Every node is named as op attribute of this node
+ (data nodes is marked by (data)):
+ TensorArray
+ | |
+ Flow(data) Handle(data)------------------------------
+ | | |
+ v v v
+ Enter -> Merge -> Switch -> Exit -> TensorArrayRead
+ |
+ |
+ |
+ |
+ --------> Identity -> TensorArrayWrite -> NextIteration
+ """
+ enabled = True
+ graph_condition = [lambda graph: graph.graph['is_cyclic']]
+
+ def run_after(self):
+ return [SmartOutputMatcher]
+
+ def run_before(self):
+ from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
+ from extensions.middle.TensorIteratorCondition import LoopConditionMatcher
+ return [TensorIteratorMerge, LoopConditionMatcher]
+
+ @staticmethod
+ def pattern():
+ return dict(
+ nodes=[
+ ('TensorArray', dict(kind='op', op='TensorArrayV3')),
+ ('TensorArray_data', dict(kind='data')),
+ ('TensorArray_flow_data', dict(kind='data')),
+
+ ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
+ ('TensorArrayWrite_data', dict(kind='data')),
+
+ ('NextIteration', dict(kind='op', op='NextIteration')),
+ ('NextIteration_data', dict(kind='data')),
+
+ ('Condition_data', dict(kind='data')),
+
+ ('Identity_2', dict(kind='op', op='Identity')),
+ ('Identity_2_data', dict(kind='data')),
+
+ ('Switch_2', dict(kind='op', op='Switch')),
+ ('Switch_2_data', dict(kind='data')),
+ ('Switch_2_data_exit', dict(kind='data')),
+
+ ('Merge_2', dict(kind='op', op='Merge')),
+ ('Merge_2_data', dict(kind='data')),
+
+ ('Enter_2', dict(kind='op', op='Enter')),
+ ('Enter_2_data', dict(kind='data')),
+
+ ('WriteEnter', dict(kind='op', op='Enter')),
+ ('WriteEnter_data', dict(kind='data')),
+
+ ('Exit', dict(kind='op', op='Exit')),
+ ('Exit_data', dict(kind='data')),
+ #
+ ('TensorArrayRead', dict(op='TensorArrayReadV3')),
+ ('TensorArrayRead_data', dict(kind='data')),
+ ],
+ edges=[
+ ('TensorArray', 'TensorArray_data'),
+ ('TensorArray', 'TensorArray_flow_data'),
+ ('TensorArray_flow_data', 'Enter_2'),
+ ('TensorArray_data', 'WriteEnter'),
+
+
+ ('Enter_2', 'Enter_2_data'),
+ ('Enter_2_data', 'Merge_2'),
+ ('Merge_2', 'Merge_2_data'),
+ ('Merge_2_data', 'Switch_2'),
+ ('Switch_2', 'Switch_2_data'),
+ ('Switch_2', 'Switch_2_data_exit'),
+ ('Switch_2_data', 'Identity_2'),
+ ('Identity_2', 'Identity_2_data'),
+
+ ('Switch_2_data_exit', 'Exit'),
+ ('Exit', 'Exit_data'),
+ ('Exit_data', 'TensorArrayRead'),
+
+ ('WriteEnter', 'WriteEnter_data'),
+ ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),
+
+ ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),
+ #
+ ('TensorArrayWrite', 'TensorArrayWrite_data'),
+ ('TensorArrayWrite_data', 'NextIteration'),
+ ('Condition_data', 'Switch_2'),
+ #
+ ('TensorArray_data', 'TensorArrayRead'),
+ ('TensorArrayRead', 'TensorArrayRead_data'),
+ ('NextIteration', 'NextIteration_data'),
+ ('NextIteration_data', 'Merge_2'),
+ ],
+ )
+
+ @staticmethod
+ def replace_pattern(graph: Graph, match: dict):
+ log.debug('================== SimpleOutputFind ===============')
+ assert match['WriteEnter_data'].value is not None
+
+ index = match['TensorArrayWrite'].in_node(1)
+ value = match['TensorArrayWrite'].in_node(2)
+
+ # axis == 0 because in TensorArray we ALWAYS iterate over 0 axis, other params will be fill later (with
+ # condition)
+ output = TensorIteratorOutput(graph, dict(
+ external_port_id=str(match['WriteEnter_data'].value),
+ internal_layer_id=value.id,
+ name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'
+ ))
+ output.create_node_with_data(inputs=[value, index],
+ data_nodes=[match['TensorArrayRead_data']])
+
+ # Delete useless nodes
+ safe_nodes = ['TensorArrayRead_data', 'Condition_data']
+ nodes_for_remove = []
+ for node in match.keys():
+ if node not in safe_nodes:
+ nodes_for_remove.append(match[node].id)
+ graph.remove_nodes_from(nodes_for_remove)