Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorCondition.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 from extensions.ops.TensorIterator_ops import TensorIteratorCondition
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
22 import numpy as np
23
24
25 class LoopConditionMatcher(MiddleReplacementPattern):
26     """
27     This pattern match condition for TensorIterator in while loops in TF.
28     The structure of pattern without Data nodes between ops. Every node is named as op attribute of this node
29     (data nodes is marked by (data)):
30                                                                    Const----
31                                                                             |
32                                                                             v
33     Const -> Enter -> Merge ---------------------> Switch -> Identity ->  Add -> NextIteration
34                         |                              ^
35                         ---> Less ----|                |
36                                 ^     |                |
37     Maximum -> Minimum -> Enter-|     |                |
38                  ^                    v                |
39 Shape -> StridedSlice -> Enter -|    LogicalAnd --> LoopCond (data)
40                                 v     ^                |
41                         ---> Less ----|                |
42                         |                              v
43     Const -> Enter -> Merge ---------------------> Switch -> Identity ->  Add -> NextIteration
44                                                                             ^
45                                                                             |
46                                                                    Const----
47     """
48     enabled = True
49     graph_condition = [lambda graph: graph.graph['is_cyclic']]
50
51     def run_after(self):
52         return []
53
54     def run_before(self):
55         from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
56         return [TensorIteratorMerge]
57
58     @staticmethod
59     def pattern():
60         log.debug('+++++++++++++++ ConditionMatching ++++++++++++++++')
61         return dict(
62             nodes=[
63                 ('Enter_1_less', dict(kind='op', op='Enter')),
64                 ('Strided_slice', dict(kind='op', op='StridedSlice')),
65                 ('Strided_slice_data', dict(kind='data')),
66                 ('Enter_1_less_data', dict(kind='data')),
67
68                 ('Less_1', dict(kind='op', op='Less')),
69                 ('Merge_1', dict(kind='op', op='Merge')),
70                 ('Merge_1_data', dict(kind='data')),
71                 ('Less_1_data', dict(kind='data')),
72
73                 ('Less_2', dict(kind='op', op='Less')),
74                 ('Merge_2', dict(kind='op', op='Merge')),
75                 ('Merge_2_data', dict(kind='data')),
76                 ('Less_2_data', dict(kind='data')),
77
78                 ('Enter_2_less', dict(kind='op', op='Enter')),
79                 ('Enter_2_less_data', dict(kind='data')),
80                 ('minimum_data', dict(kind='data')),
81
82                 ('and', dict(kind='op', op='LogicalAnd')),
83                 ('and_data', dict(kind='data')),
84                 ('loop_cond', dict(kind='op', op='LoopCond')),
85                 ('loop_cond_data', dict(kind='data')),
86
87                 ('init_1', dict(kind='op', op='Const')),
88                 ('init_1_data', dict(kind='data')),
89                 ('Enter_1', dict(kind='op', op='Enter')),
90                 ('Enter_1_data', dict(kind='data')),
91
92                 ('init_2', dict(kind='op', op='Const')),
93                 ('init_2_data', dict(kind='data')),
94                 ('Enter_2', dict(kind='op', op='Enter')),
95                 ('Enter_2_data', dict(kind='data')),
96
97                 ('Switch_1', dict(kind='op', op='Switch')),
98                 ('Switch_1_data', dict(kind='data')),
99                 ('Identity_1', dict(kind='op', op='Identity')),
100                 ('Identity_1_data', dict(kind='data')),
101                 ('add_1', dict(kind='op', op='Add')),
102                 ('add_1_y', dict(kind='op', op='Const')),
103                 ('add_1_y_data', dict(kind='data')),
104                 ('add_1_data', dict(kind='data')),
105                 ('NextIteration_1', dict(kind='op', op='NextIteration')),
106
107                 ('Switch_2', dict(kind='op', op='Switch')),
108                 ('Switch_2_data', dict(kind='data')),
109                 ('Identity_2', dict(kind='op', op='Identity')),
110                 ('Identity_2_data', dict(kind='data')),
111                 ('add_2', dict(kind='op', op='Add')),
112                 ('add_2_y', dict(kind='op', op='Const')),
113                 ('add_2_y_data', dict(kind='data')),
114                 ('add_2_data', dict(kind='data')),
115                 ('NextIteration_2', dict(kind='op', op='NextIteration')),
116
117             ],
118             edges=[
119                 ('Strided_slice', 'Strided_slice_data'),
120                 ('Strided_slice_data', 'Enter_1_less'),
121                 ('Enter_1_less', 'Enter_1_less_data'),
122                 ('Enter_1_less_data', 'Less_1'),
123                 ('Less_1', 'Less_1_data'),
124                 ('Less_1_data', 'and'),
125
126                 ('and', 'and_data'),
127                 ('and_data', 'loop_cond'),
128                 ('loop_cond', 'loop_cond_data'),
129                 ('loop_cond_data', 'Switch_1'),
130                 ('loop_cond_data', 'Switch_2'),
131
132                 ('init_1', 'init_1_data'),
133                 ('init_1_data', 'Enter_1'),
134                 ('Enter_1', 'Enter_1_data'),
135                 ('Enter_1_data', 'Merge_1'),
136                 ('Merge_1', 'Merge_1_data'),
137                 ('Merge_1_data', 'Less_1'),
138
139                 ('Merge_1_data', 'Switch_1'),
140                 ('Switch_1', 'Switch_1_data'),
141                 ('Switch_1_data', 'Identity_1'),
142                 ('Identity_1', 'Identity_1_data'),
143                 ('Identity_1_data', 'add_1'),
144                 ('add_1_y', 'add_1_y_data'),
145                 ('add_1_y_data', 'add_1'),
146                 ('add_1', 'add_1_data'),
147                 ('add_1_data', 'NextIteration_1'),
148
149                 ('Merge_2_data', 'Switch_2'),
150                 ('Switch_2', 'Switch_2_data'),
151                 ('Switch_2_data', 'Identity_2'),
152                 ('Identity_2', 'Identity_2_data'),
153                 ('Identity_2_data', 'add_2'),
154                 ('add_2_y', 'add_2_y_data'),
155                 ('add_2_y_data', 'add_2'),
156                 ('add_2', 'add_2_data'),
157                 ('add_2_data', 'NextIteration_2'),
158
159                 ('minimum_data', 'Enter_2_less'),
160                 ('Enter_2_less', 'Enter_2_less_data'),
161                 ('Enter_2_less_data', 'Less_2'),
162
163                 ('init_2', 'init_2_data'),
164                 ('init_2_data', 'Enter_2'),
165                 ('Enter_2', 'Enter_2_data'),
166                 ('Enter_2_data', 'Merge_2'),
167
168                 ('Merge_2', 'Merge_2_data'),
169                 ('Merge_2_data', 'Less_2'),
170                 ('Less_2', 'Less_2_data'),
171                 ('Less_2_data', 'and'),
172             ],
173         )
174
175     @staticmethod
176     def looking_for_iteration_counter(graph: Graph, match: dict):
177         types = ['TensorIteratorInput', 'TensorIteratorOutput']
178         candidates = np.array([match['Identity_1_data'], match['Identity_2_data']])
179         results = np.array([False for i in range(len(candidates))])
180         for i, candidat in enumerate(candidates):
181             for node in candidat.out_nodes():
182                 if node['op'] in types:
183                     results[i] = True
184         assert not np.all(results)
185         assert sum(results) == 1
186         return candidates[results == True][0]
187
188     def replace_pattern(self, graph: Graph, match: dict):
189         log.debug('================== ConditionFind ===============')
190         # init_1
191         init_1 = match['init_1_data'].value
192         assert init_1 is not None
193         init_1 = int(init_1)
194
195         # init_2
196         init_2 = match['init_2_data'].value
197         assert init_2 is not None
198         init_2 = int(init_2)
199
200         # step_1
201         assert match['add_1_y_data'].value is not None
202         step_1 = int(match['add_1_y_data'].value)
203
204         # step_2
205         assert match['add_2_y_data'].value is not None
206         step_2 = int(match['add_2_y_data'].value)
207
208         match['loop_cond_data'].value = None
209         match['Identity_2_data'].value = None
210
211         # Create condition node and delete all useless nodes from condition pattern
212         loop_condiiton = match['loop_cond_data']
213         iterator_data = self.looking_for_iteration_counter(graph, match)
214
215         condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
216                                name=match['loop_cond'].name + '/TensorIteratorCondition_')
217         condition = TensorIteratorCondition(graph, attrs=condition_attrs)
218         condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
219                                         data_nodes=[loop_condiiton, iterator_data])
220
221         # Delete useless nodes
222         safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data',  'Strided_slice', 'Strided_slice_data',
223                       'minimum', 'minimum_data']
224         nodes_for_remove = []
225         for node in match.keys():
226             if node not in safe_nodes:
227                 nodes_for_remove.append(match[node].id)
228         graph.remove_nodes_from(nodes_for_remove)
229
230
231 class SimpleConditionMatcher(MiddleReplacementPattern):
232     enabled = True
233     graph_condition = [lambda graph: graph.graph['is_cyclic']]
234
235     def run_after(self):
236         return [LoopConditionMatcher]
237
238     def run_before(self):
239         from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
240         return [TensorIteratorMerge]
241
242     @staticmethod
243     def pattern():
244         log.debug('+++++++++++++++ SimpleConditionMatching ++++++++++++++++')
245         return dict(
246             nodes=[
247                 ('Enter_1_less', dict(kind='op', op='Enter')),
248                 ('Strided_slice', dict(kind='op', op='StridedSlice')),
249                 ('Strided_slice_data', dict(kind='data')),
250                 ('Enter_1_less_data', dict(kind='data')),
251
252                 ('Less_1', dict(kind='op', op='Less')),
253                 ('Merge_1', dict(kind='op', op='Merge')),
254                 ('Merge_1_data', dict(kind='data')),
255                 ('Less_1_data', dict(kind='data')),
256
257                 ('loop_cond', dict(kind='op', op='LoopCond')),
258                 ('loop_cond_data', dict(kind='data')),
259
260                 ('init_1', dict(kind='op', op='Const')),
261                 ('init_1_data', dict(kind='data')),
262                 ('Enter_1', dict(kind='op', op='Enter')),
263                 ('Enter_1_data', dict(kind='data')),
264
265                 ('Switch_1', dict(kind='op', op='Switch')),
266                 ('Switch_1_data', dict(kind='data')),
267                 ('Identity_1', dict(kind='op', op='Identity')),
268                 ('Identity_1_data', dict(kind='data')),
269                 ('add_1', dict(kind='op', op='Add')),
270                 ('add_1_y', dict(kind='op', op='Const')),
271                 ('add_1_y_data', dict(kind='data')),
272                 ('add_1_data', dict(kind='data')),
273                 ('NextIteration_1', dict(kind='op', op='NextIteration')),
274             ],
275             edges=[
276                 ('Strided_slice', 'Strided_slice_data'),
277                 ('Strided_slice_data', 'Enter_1_less'),
278                 ('Enter_1_less', 'Enter_1_less_data'),
279                 ('Enter_1_less_data', 'Less_1'),
280                 ('Less_1', 'Less_1_data'),
281                 ('Less_1_data', 'loop_cond'),
282
283                 ('loop_cond', 'loop_cond_data'),
284                 ('loop_cond_data', 'Switch_1'),
285
286                 ('init_1', 'init_1_data'),
287                 ('init_1_data', 'Enter_1'),
288                 ('Enter_1', 'Enter_1_data'),
289                 ('Enter_1_data', 'Merge_1'),
290                 ('Merge_1', 'Merge_1_data'),
291                 ('Merge_1_data', 'Less_1'),
292
293                 ('Merge_1_data', 'Switch_1'),
294                 ('Switch_1', 'Switch_1_data'),
295                 ('Switch_1_data', 'Identity_1'),
296                 ('Identity_1', 'Identity_1_data'),
297                 ('Identity_1_data', 'add_1'),
298                 ('add_1_y', 'add_1_y_data'),
299                 ('add_1_y_data', 'add_1'),
300                 ('add_1', 'add_1_data'),
301                 ('add_1_data', 'NextIteration_1'),
302
303             ],
304         )
305
306     @staticmethod
307     def replace_pattern(graph: Graph, match: dict):
308         log.debug('================== SimpleConditionFind ===============')
309         # init_1
310         init_1 = match['init_1_data'].value
311         assert init_1 is not None
312         init_1 = int(init_1)
313
314         # step_1
315         assert match['add_1_y_data'].value is not None
316         step_1 = int(match['add_1_y_data'].value)
317
318         match['loop_cond_data'].value = None
319
320         # Create condition node and delete all useless nodes from condition pattern
321         condition_attrs = dict(iter=dict(init=init_1, step=step_1),
322                                name=match['loop_cond'].name + '/TensorIteratorCondition_')
323         condition = TensorIteratorCondition(graph, attrs=condition_attrs)
324         condition.create_node_with_data(inputs=[match['Strided_slice_data']],
325                                         data_nodes=[match['loop_cond_data'], match['Identity_1_data']])
326
327         # Delete useless nodes
328         safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Strided_slice', 'Strided_slice_data']
329         nodes_for_remove = []
330         for node in match.keys():
331             if node not in safe_nodes:
332                 nodes_for_remove.append(match[node].id)
333         graph.remove_nodes_from(nodes_for_remove)