Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorBackEdge.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.ops.TensorIterator_ops import TensorIteratorBackEdge, TensorIteratorOutput
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
22
23
class BackEdgesMatching(MiddleReplacementPattern):
    """
    This pattern is needed for matching back edges in while loops in TF graphs.
    A back edge is a chain of nodes in a while loop that carries one variable of the graph over loop steps.
    It consists of the following nodes:
                        Exit (optional)
                            ^
                            |
    Enter () -> Merge -> Switch -> Identity -> SOME OPERATIONS -> NextIteration ->
                ^                                                                 |
                |                                                                 |
                ------------------------------------------------------------------
    The structure of the pattern without Data nodes between ops (every node is named by its op attribute):
                Data--
                      |
        NextIteration -> Merge--
                                |
                                ->Switch (out=1) -> Identity
                                |
       TensorIteratorCondition--

    The matched chain is replaced by a TensorIteratorBackEdge op whose inputs are the Enter data, the data
    fed into NextIteration from the loop body and the condition data, and whose output is the former
    Identity output data node. If Switch also has an output on port 0 (the Exit path), a
    TensorIteratorOutput op is created that writes to the Exit op's output data.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        from extensions.middle.TensorIteratorCondition import SimpleConditionMatcher
        return [SimpleConditionMatcher]

    def run_before(self):
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    @staticmethod
    def pattern():
        return dict(
            nodes=[
                ('Enter_1_data', dict(kind='data')),

                ('Merge_1', dict(kind='op', op='Merge')),
                ('Merge_1_data', dict(kind='data')),

                ('Switch_1', dict(kind='op', op='Switch')),
                ('Switch_1_data', dict(kind='data')),

                ('Identity_1', dict(kind='op', op='Identity')),
                ('Identity_1_data', dict(kind='data')),

                ('NextIteration', dict(kind='op', op='NextIteration')),
                ('NextIteration_data', dict(kind='data')),

                ('condition', dict(kind='op', op='TensorIteratorCondition')),
                ('condition_cond_data', dict(kind='data')),
            ],
            edges=[
                ('Enter_1_data', 'Merge_1'),
                ('Merge_1', 'Merge_1_data'),

                ('Merge_1_data', 'Switch_1'),
                ('Switch_1', 'Switch_1_data', {'out': 1}),
                ('Switch_1_data', 'Identity_1'),
                ('Identity_1', 'Identity_1_data'),

                ('NextIteration', 'NextIteration_data'),
                ('NextIteration_data', 'Merge_1'),

                ('condition', 'condition_cond_data'),
                ('condition_cond_data', 'Switch_1'),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched back edge with a TensorIteratorBackEdge op (plus a TensorIteratorOutput op for
        the optional Exit path) and remove the nodes of the original loop construct.

        :param graph: graph being transformed (modified in place).
        :param match: mapping from pattern node names (see pattern()) to matched graph nodes.
        """
        log.debug('================== BackEdgeFind ===============')

        nodes_for_remove = []
        # Data node that feeds NextIteration, i.e. the value produced by the loop body
        from_body_data = match['NextIteration'].in_node()

        # If the Exit path exists -> create TensorIteratorOutput for it
        if 0 in match['Switch_1'].out_nodes():
            # Switch (port 0) -> Switch data -> Exit -> Exit data: the data node on port 0 is not the
            # Exit op itself, so step one more hop to reach the Exit op
            switch_exit_data = match['Switch_1'].out_node(0)
            exit_op = switch_exit_data.out_node(0)
            assert exit_op.has_valid('op') and exit_op.op == 'Exit', \
                'Expected an Exit op after output port 0 of {}'.format(match['Switch_1'].name)
            output_data = exit_op.out_node(0)

            # Neither node is part of the match, so they must be removed explicitly
            nodes_for_remove.extend([switch_exit_data.id, exit_op.id])

            # Creating TensorIteratorOutput without partition
            output = TensorIteratorOutput(graph, dict(external_port_id=None,
                                                      internal_layer_id=None,
                                                      name=exit_op.name + '/TensorIteratorOutput_'))
            output.create_node_with_data(inputs=[from_body_data, match['condition_cond_data']],
                                         data_nodes=[output_data])

        assert match['NextIteration_data'].id != match['Enter_1_data'].id
        backedge = TensorIteratorBackEdge(graph, dict(name=match['Identity_1'].name + '/TensorIteratorBackEdge_'))
        backedge.create_node_with_data(inputs=[match['Enter_1_data'], from_body_data, match['condition_cond_data']],
                                       data_nodes=[match['Identity_1_data']])

        # Delete useless nodes. The safe nodes are inputs/outputs of the newly created ops
        # ('condition' is kept as the producer of 'condition_cond_data').
        safe_nodes = {'Identity_1_data', 'condition', 'condition_cond_data', 'Enter_1_data'}
        nodes_for_remove.extend(node.id for name, node in match.items() if name not in safe_nodes)
        graph.remove_nodes_from(nodes_for_remove)