Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorOutput.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.ops.TensorIterator_ops import TensorIteratorOutput
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
22
23
class SmartOutputMatcher(MiddleReplacementPattern):
    """
    Matches the partitioned-output sub-graph of a TensorIterator built from a TF dynamic_rnn loop and
    replaces it with a single TensorIteratorOutput operation.

    The diagram below shows the structure of the pattern with the Data nodes between ops omitted. Every node
    is named after its op attribute (data nodes are marked with (data)):
        TensorArray
        |         |                                                                           Condition(data)
    Flow(data)  Handle(data)---------------------------------------------------------------     |
            |    |                                       |                                 |    |
            v    v                                       v                                 v    v
            Enter  ->  Merge -> Switch -> Exit -> TensorArraySize -> Range(0;1) -> TensorArrayGather
                                    |       |                                            ^
                                    |       |                                            |
                                    |        ---------------------------------------------
                                    |
                                    --------> Identity -> TensorArrayWrite -> NextIteration
    """
    enabled = True
    # The transformation is applied only to graphs flagged as cyclic.
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        """Return the transformations that must be executed before this one."""
        from extensions.middle.TensorIteratorInput import SmartInputMatcher
        return [SmartInputMatcher]

    def run_before(self):
        """Return the transformations that must be executed after this one."""
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    @staticmethod
    def pattern():
        """
        Describe the sub-graph to look for.

        :return: dict with two keys: 'nodes' - list of (name, attributes) pairs, and 'edges' - list of
                 (source, destination[, edge attributes]) tuples that refer to the node names.
        """
        nodes = [
            ('TensorArray', dict(kind='op', op='TensorArrayV3')),
            ('TensorArray_data', dict(kind='data')),
            ('TensorArray_flow_data', dict(kind='data')),
            ('TensorArrayGather', dict(kind='op', op='TensorArrayGatherV3')),
            ('TensorArrayGather_data', dict(kind='data')),
            ('range', dict(kind='op', op='Range')),
            ('range_data', dict(kind='data')),
            ('size', dict(kind='op', op='TensorArraySizeV3')),
            ('size_data', dict(kind='data')),
            ('start', dict(kind='op', op='Const')),
            ('start_data', dict(kind='data')),
            ('delta', dict(kind='op', op='Const')),
            ('delta_data', dict(kind='data')),
            ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
            ('TensorArrayWrite_data', dict(kind='data')),
            ('NextIteration', dict(kind='op', op='NextIteration')),
            ('Condition_data', dict(kind='data')),
            ('Identity_2_data', dict(kind='data')),
            ('Identity_2', dict(kind='op', op='Identity')),
            ('Switch_2', dict(kind='op', op='Switch')),
            ('Switch_2_data', dict(kind='data')),
            ('Switch_2_data_exit', dict(kind='data')),
            ('Merge_2', dict(kind='op', op='Merge')),
            ('Merge_2_data', dict(kind='data')),
            ('Enter_2', dict(kind='op', op='Enter')),
            ('Enter_2_data', dict(kind='data')),
            ('WriteEnter', dict(kind='op', op='Enter')),
            ('WriteEnter_data', dict(kind='data')),
            ('Exit', dict(kind='op', op='Exit')),
            ('Exit_data', dict(kind='data')),
        ]
        edges = [
            # TensorArray outputs and their consumers
            ('TensorArray', 'TensorArray_data'),
            ('TensorArray', 'TensorArray_flow_data'),
            ('TensorArray_flow_data', 'Enter_2'),
            ('TensorArray_data', 'WriteEnter'),
            ('TensorArray_data', 'TensorArrayGather'),
            ('TensorArrayGather', 'TensorArrayGather_data'),
            ('TensorArray_data', 'size'),

            # Producers of the Range inputs
            ('size', 'size_data'),
            ('start', 'start_data'),
            ('delta', 'delta_data'),

            # Range(start, size, delta) feeding TensorArrayGather
            ('size_data', 'range', {'in': 1}),
            ('start_data', 'range', {'in': 0}),
            ('delta_data', 'range', {'in': 2}),
            ('range', 'range_data'),
            ('range_data', 'TensorArrayGather'),

            # Enter -> Merge -> Switch -> Identity chain
            ('Enter_2', 'Enter_2_data'),
            ('Enter_2_data', 'Merge_2'),
            ('Merge_2', 'Merge_2_data'),
            ('Merge_2_data', 'Switch_2'),
            ('Switch_2', 'Switch_2_data'),
            ('Switch_2', 'Switch_2_data_exit'),
            ('Switch_2_data', 'Identity_2'),
            ('Identity_2', 'Identity_2_data'),

            # Exit branch of the Switch
            ('Switch_2_data_exit', 'Exit'),
            ('Exit', 'Exit_data'),
            ('Exit_data', 'size'),
            ('Exit_data', 'TensorArrayGather'),

            # TensorArrayWrite inputs 0 and 3
            ('WriteEnter', 'WriteEnter_data'),
            ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),

            ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),

            # Back edge and the loop condition
            ('TensorArrayWrite', 'TensorArrayWrite_data'),
            ('TensorArrayWrite_data', 'NextIteration'),
            ('Condition_data', 'Switch_2'),
        ]
        return {'nodes': nodes, 'edges': edges}

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched sub-graph with a TensorIteratorOutput operation.

        The new operation takes the TensorArray input plus inputs 2 and 1 of TensorArrayWrite (in that
        order) and produces the already existing 'TensorArrayGather_data' node. Every matched node except
        'TensorArrayGather_data' and 'Condition_data' is removed from the graph afterwards.

        :param graph: graph to modify
        :param match: mapping from the pattern node names to the matched graph nodes
        :raises AssertionError: if 'WriteEnter_data' has no value, or the Range start/delta are not 0/1
        """
        log.debug('================== SmartOutputFind ===============')

        write_enter_data = match['WriteEnter_data']
        assert write_enter_data.value is not None
        assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1

        ta_write = match['TensorArrayWrite']
        array_size = match['TensorArray'].in_node()
        write_index = ta_write.in_node(1)
        written_value = ta_write.in_node(2)

        # axis is fixed to 0 here; start, stride and part_size are left as None at this point.
        output_attrs = dict(axis=0, start=None, stride=None, part_size=None,
                            external_port_id=str(write_enter_data.value),
                            internal_layer_id=written_value.id,
                            name=ta_write.name + '/TensorIteratorOutput_')
        ti_output = TensorIteratorOutput(graph, output_attrs)
        ti_output.create_node_with_data(inputs=[array_size, written_value, write_index],
                                        data_nodes=[match['TensorArrayGather_data']])

        # Drop every matched node that is not needed any more.
        preserved = ('TensorArrayGather_data', 'Condition_data')
        obsolete_ids = [match[name].id for name in match.keys() if name not in preserved]
        graph.remove_nodes_from(obsolete_ids)
160
161
class SimpleOutputMatcher(MiddleReplacementPattern):
    """
    Matches the output sub-graph of a TensorIterator built from a TF dynamic_rnn loop in which the result is
    taken with TensorArrayRead, and replaces it with a single TensorIteratorOutput operation.

    The diagram below shows the structure of the pattern with the Data nodes between ops omitted. Every node
    is named after its op attribute (data nodes are marked with (data)):
        TensorArray
        |         |
    Flow(data)  Handle(data)------------------------------
            |    |                                       |
            v    v                                       v
            Enter  ->  Merge -> Switch -> Exit -> TensorArrayRead
                                    |
                                    |
                                    |
                                    |
                                    --------> Identity -> TensorArrayWrite -> NextIteration
    """
    enabled = True
    # The transformation is applied only to graphs flagged as cyclic.
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        """Return the transformations that must be executed before this one."""
        return [SmartOutputMatcher]

    def run_before(self):
        """Return the transformations that must be executed after this one."""
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        from extensions.middle.TensorIteratorCondition import LoopConditionMatcher
        return [TensorIteratorMerge, LoopConditionMatcher]

    @staticmethod
    def pattern():
        """
        Describe the sub-graph to look for.

        :return: dict with two keys: 'nodes' - list of (name, attributes) pairs, and 'edges' - list of
                 (source, destination[, edge attributes]) tuples that refer to the node names.
        """
        nodes = [
            ('TensorArray', dict(kind='op', op='TensorArrayV3')),
            ('TensorArray_data', dict(kind='data')),
            ('TensorArray_flow_data', dict(kind='data')),

            ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
            ('TensorArrayWrite_data', dict(kind='data')),

            ('NextIteration', dict(kind='op', op='NextIteration')),
            ('NextIteration_data', dict(kind='data')),

            ('Condition_data', dict(kind='data')),

            ('Identity_2', dict(kind='op', op='Identity')),
            ('Identity_2_data', dict(kind='data')),

            ('Switch_2', dict(kind='op', op='Switch')),
            ('Switch_2_data', dict(kind='data')),
            ('Switch_2_data_exit', dict(kind='data')),

            ('Merge_2', dict(kind='op', op='Merge')),
            ('Merge_2_data', dict(kind='data')),

            ('Enter_2', dict(kind='op', op='Enter')),
            ('Enter_2_data', dict(kind='data')),

            ('WriteEnter', dict(kind='op', op='Enter')),
            ('WriteEnter_data', dict(kind='data')),

            ('Exit', dict(kind='op', op='Exit')),
            ('Exit_data', dict(kind='data')),

            # This node is constrained by its op attribute only (no 'kind' constraint).
            ('TensorArrayRead', dict(op='TensorArrayReadV3')),
            ('TensorArrayRead_data', dict(kind='data')),
        ]
        edges = [
            # TensorArray outputs and their consumers inside the loop
            ('TensorArray', 'TensorArray_data'),
            ('TensorArray', 'TensorArray_flow_data'),
            ('TensorArray_flow_data', 'Enter_2'),
            ('TensorArray_data', 'WriteEnter'),

            # Enter -> Merge -> Switch -> Identity chain
            ('Enter_2', 'Enter_2_data'),
            ('Enter_2_data', 'Merge_2'),
            ('Merge_2', 'Merge_2_data'),
            ('Merge_2_data', 'Switch_2'),
            ('Switch_2', 'Switch_2_data'),
            ('Switch_2', 'Switch_2_data_exit'),
            ('Switch_2_data', 'Identity_2'),
            ('Identity_2', 'Identity_2_data'),

            # Exit branch of the Switch
            ('Switch_2_data_exit', 'Exit'),
            ('Exit', 'Exit_data'),
            ('Exit_data', 'TensorArrayRead'),

            # TensorArrayWrite inputs 0 and 3
            ('WriteEnter', 'WriteEnter_data'),
            ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),

            ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),

            # Path towards NextIteration and the loop condition
            ('TensorArrayWrite', 'TensorArrayWrite_data'),
            ('TensorArrayWrite_data', 'NextIteration'),
            ('Condition_data', 'Switch_2'),

            # TensorArrayRead handle input/output and the back edge into Merge
            ('TensorArray_data', 'TensorArrayRead'),
            ('TensorArrayRead', 'TensorArrayRead_data'),
            ('NextIteration', 'NextIteration_data'),
            ('NextIteration_data', 'Merge_2'),
        ]
        return {'nodes': nodes, 'edges': edges}

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched sub-graph with a TensorIteratorOutput operation.

        The new operation takes inputs 2 and 1 of TensorArrayWrite (in that order) and produces the already
        existing 'TensorArrayRead_data' node. Every matched node except 'TensorArrayRead_data' and
        'Condition_data' is removed from the graph afterwards.

        :param graph: graph to modify
        :param match: mapping from the pattern node names to the matched graph nodes
        :raises AssertionError: if 'WriteEnter_data' has no value
        """
        log.debug('================== SimpleOutputFind ===============')
        write_enter_data = match['WriteEnter_data']
        assert write_enter_data.value is not None

        ta_write = match['TensorArrayWrite']
        write_index = ta_write.in_node(1)
        written_value = ta_write.in_node(2)

        # Only the port id, the internal layer id and the name are passed here; no axis, start, stride or
        # part_size attribute is set by this transformation.
        output_attrs = dict(external_port_id=str(write_enter_data.value),
                            internal_layer_id=written_value.id,
                            name=ta_write.name + '/TensorIteratorOutput_')
        ti_output = TensorIteratorOutput(graph, output_attrs)
        ti_output.create_node_with_data(inputs=[written_value, write_index],
                                        data_nodes=[match['TensorArrayRead_data']])

        # Drop every matched node that is not needed any more.
        preserved = ('TensorArrayRead_data', 'Condition_data')
        obsolete_ids = [match[name].id for name in match.keys() if name not in preserved]
        graph.remove_nodes_from(obsolete_ids)