"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from extensions.ops.TensorIterator_ops import TensorIteratorCondition
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
class LoopConditionMatcher(MiddleReplacementPattern):
    """
    This pattern match condition for TensorIterator in while loops in TF.
    The structure of pattern without Data nodes between ops. Every node is named as op attribute of this node
    (data nodes is marked by (data)):

    Const -> Enter -> Merge ---------------------> Switch -> Identity -> Add -> NextIteration
                        |                             ^
                        ---> Less ----|               |
                              ^       |               |
    Maximum -> Minimum -> Enter-|     |               |
                                      v               |
    Shape -> StridedSlice -> Enter -|  LogicalAnd --> LoopCond (data)
                                      ^               |
                        ---> Less ----|               |
                        |                             v
    Const -> Enter -> Merge ---------------------> Switch -> Identity -> Add -> NextIteration
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        return []

    def run_before(self):
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    @staticmethod
    def pattern():
        """Return the sub-graph pattern (nodes + edges) of the two-counter TF while-loop condition."""
        log.debug('+++++++++++++++ ConditionMatching ++++++++++++++++')
        return dict(
            nodes=[
                ('Enter_1_less', dict(kind='op', op='Enter')),
                ('Strided_slice', dict(kind='op', op='StridedSlice')),
                ('Strided_slice_data', dict(kind='data')),
                ('Enter_1_less_data', dict(kind='data')),

                ('Less_1', dict(kind='op', op='Less')),
                ('Merge_1', dict(kind='op', op='Merge')),
                ('Merge_1_data', dict(kind='data')),
                ('Less_1_data', dict(kind='data')),

                ('Less_2', dict(kind='op', op='Less')),
                ('Merge_2', dict(kind='op', op='Merge')),
                ('Merge_2_data', dict(kind='data')),
                ('Less_2_data', dict(kind='data')),

                ('Enter_2_less', dict(kind='op', op='Enter')),
                ('Enter_2_less_data', dict(kind='data')),
                ('minimum_data', dict(kind='data')),

                ('and', dict(kind='op', op='LogicalAnd')),
                ('and_data', dict(kind='data')),
                ('loop_cond', dict(kind='op', op='LoopCond')),
                ('loop_cond_data', dict(kind='data')),

                ('init_1', dict(kind='op', op='Const')),
                ('init_1_data', dict(kind='data')),
                ('Enter_1', dict(kind='op', op='Enter')),
                ('Enter_1_data', dict(kind='data')),

                ('init_2', dict(kind='op', op='Const')),
                ('init_2_data', dict(kind='data')),
                ('Enter_2', dict(kind='op', op='Enter')),
                ('Enter_2_data', dict(kind='data')),

                ('Switch_1', dict(kind='op', op='Switch')),
                ('Switch_1_data', dict(kind='data')),
                ('Identity_1', dict(kind='op', op='Identity')),
                ('Identity_1_data', dict(kind='data')),
                ('add_1', dict(kind='op', op='Add')),
                ('add_1_y', dict(kind='op', op='Const')),
                ('add_1_y_data', dict(kind='data')),
                ('add_1_data', dict(kind='data')),
                ('NextIteration_1', dict(kind='op', op='NextIteration')),

                ('Switch_2', dict(kind='op', op='Switch')),
                ('Switch_2_data', dict(kind='data')),
                ('Identity_2', dict(kind='op', op='Identity')),
                ('Identity_2_data', dict(kind='data')),
                ('add_2', dict(kind='op', op='Add')),
                ('add_2_y', dict(kind='op', op='Const')),
                ('add_2_y_data', dict(kind='data')),
                ('add_2_data', dict(kind='data')),
                ('NextIteration_2', dict(kind='op', op='NextIteration')),
            ],
            edges=[
                ('Strided_slice', 'Strided_slice_data'),
                ('Strided_slice_data', 'Enter_1_less'),
                ('Enter_1_less', 'Enter_1_less_data'),
                ('Enter_1_less_data', 'Less_1'),
                ('Less_1', 'Less_1_data'),
                ('Less_1_data', 'and'),

                # ('and', 'and_data') reconstructed from a gap in the mangled source;
                # required so the LogicalAnd output reaches LoopCond.
                ('and', 'and_data'),
                ('and_data', 'loop_cond'),
                ('loop_cond', 'loop_cond_data'),
                ('loop_cond_data', 'Switch_1'),
                ('loop_cond_data', 'Switch_2'),

                ('init_1', 'init_1_data'),
                ('init_1_data', 'Enter_1'),
                ('Enter_1', 'Enter_1_data'),
                ('Enter_1_data', 'Merge_1'),
                ('Merge_1', 'Merge_1_data'),
                ('Merge_1_data', 'Less_1'),

                ('Merge_1_data', 'Switch_1'),
                ('Switch_1', 'Switch_1_data'),
                ('Switch_1_data', 'Identity_1'),
                ('Identity_1', 'Identity_1_data'),
                ('Identity_1_data', 'add_1'),
                ('add_1_y', 'add_1_y_data'),
                ('add_1_y_data', 'add_1'),
                ('add_1', 'add_1_data'),
                ('add_1_data', 'NextIteration_1'),

                ('Merge_2_data', 'Switch_2'),
                ('Switch_2', 'Switch_2_data'),
                ('Switch_2_data', 'Identity_2'),
                ('Identity_2', 'Identity_2_data'),
                ('Identity_2_data', 'add_2'),
                ('add_2_y', 'add_2_y_data'),
                ('add_2_y_data', 'add_2'),
                ('add_2', 'add_2_data'),
                ('add_2_data', 'NextIteration_2'),

                ('minimum_data', 'Enter_2_less'),
                ('Enter_2_less', 'Enter_2_less_data'),
                ('Enter_2_less_data', 'Less_2'),

                ('init_2', 'init_2_data'),
                ('init_2_data', 'Enter_2'),
                ('Enter_2', 'Enter_2_data'),
                ('Enter_2_data', 'Merge_2'),

                ('Merge_2', 'Merge_2_data'),
                ('Merge_2_data', 'Less_2'),
                ('Less_2', 'Less_2_data'),
                ('Less_2_data', 'and'),
            ],
        )

    @staticmethod
    def looking_for_iteration_counter(graph: Graph, match: dict):
        """Return whichever of the two Identity data nodes feeds a TensorIteratorInput/Output op.

        Exactly one of the two loop counters is expected to be consumed by a
        TensorIterator port node; both asserts enforce that exactly one matches.
        """
        types = ['TensorIteratorInput', 'TensorIteratorOutput']
        candidates = np.array([match['Identity_1_data'], match['Identity_2_data']])
        results = np.array([False for i in range(len(candidates))])
        for i, candidate in enumerate(candidates):
            for node in candidate.out_nodes():
                if node['op'] in types:
                    # reconstructed from a gap in the mangled source: mark this
                    # candidate as the iteration counter
                    results[i] = True
        assert not np.all(results)
        assert sum(results) == 1
        return candidates[results == True][0]

    def replace_pattern(self, graph: Graph, match: dict):
        """Collapse the matched condition sub-graph into one TensorIteratorCondition node.

        Reads the two counters' initial values and step constants, creates the
        condition op fed by the trip-count (Strided_slice) and minimum data
        nodes, then removes every matched node that is not in safe_nodes.
        """
        log.debug('================== ConditionFind ===============')
        # init_1
        init_1 = match['init_1_data'].value
        assert init_1 is not None
        init_1 = int(init_1)

        # init_2
        init_2 = match['init_2_data'].value
        assert init_2 is not None
        init_2 = int(init_2)

        # step_1
        assert match['add_1_y_data'].value is not None
        step_1 = int(match['add_1_y_data'].value)

        # step_2
        assert match['add_2_y_data'].value is not None
        step_2 = int(match['add_2_y_data'].value)

        # drop stale constant-folded values; these data nodes are re-driven by
        # the new condition node below
        match['loop_cond_data'].value = None
        match['Identity_2_data'].value = None

        # Create condition node and delete all useless nodes from condition pattern
        loop_condition = match['loop_cond_data']  # fixed misspelled local 'loop_condiiton'
        iterator_data = self.looking_for_iteration_counter(graph, match)

        condition_attrs = dict(time=dict(init=init_2, step=step_2), iter=dict(init=init_1, step=step_1),
                               name=match['loop_cond'].name + '/TensorIteratorCondition_')
        condition = TensorIteratorCondition(graph, attrs=condition_attrs)
        condition.create_node_with_data(inputs=[match['Strided_slice_data'], match['minimum_data']],
                                        data_nodes=[loop_condition, iterator_data])

        # Delete useless nodes
        safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Identity_2_data', 'Strided_slice', 'Strided_slice_data',
                      'minimum', 'minimum_data']
        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)
class SimpleConditionMatcher(MiddleReplacementPattern):
    """Match the single-counter TF while-loop condition sub-graph
    (Less -> LoopCond, no LogicalAnd) and replace it with a
    TensorIteratorCondition node. Runs after LoopConditionMatcher so the
    two-counter form is handled first.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        return [LoopConditionMatcher]

    def run_before(self):
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    @staticmethod
    def pattern():
        """Return the sub-graph pattern (nodes + edges) of the single-counter condition."""
        log.debug('+++++++++++++++ SimpleConditionMatching ++++++++++++++++')
        return dict(
            nodes=[
                ('Enter_1_less', dict(kind='op', op='Enter')),
                ('Strided_slice', dict(kind='op', op='StridedSlice')),
                ('Strided_slice_data', dict(kind='data')),
                ('Enter_1_less_data', dict(kind='data')),

                ('Less_1', dict(kind='op', op='Less')),
                ('Merge_1', dict(kind='op', op='Merge')),
                ('Merge_1_data', dict(kind='data')),
                ('Less_1_data', dict(kind='data')),

                ('loop_cond', dict(kind='op', op='LoopCond')),
                ('loop_cond_data', dict(kind='data')),

                ('init_1', dict(kind='op', op='Const')),
                ('init_1_data', dict(kind='data')),
                ('Enter_1', dict(kind='op', op='Enter')),
                ('Enter_1_data', dict(kind='data')),

                ('Switch_1', dict(kind='op', op='Switch')),
                ('Switch_1_data', dict(kind='data')),
                ('Identity_1', dict(kind='op', op='Identity')),
                ('Identity_1_data', dict(kind='data')),
                ('add_1', dict(kind='op', op='Add')),
                ('add_1_y', dict(kind='op', op='Const')),
                ('add_1_y_data', dict(kind='data')),
                ('add_1_data', dict(kind='data')),
                ('NextIteration_1', dict(kind='op', op='NextIteration')),
            ],
            edges=[
                ('Strided_slice', 'Strided_slice_data'),
                ('Strided_slice_data', 'Enter_1_less'),
                ('Enter_1_less', 'Enter_1_less_data'),
                ('Enter_1_less_data', 'Less_1'),
                ('Less_1', 'Less_1_data'),
                ('Less_1_data', 'loop_cond'),

                ('loop_cond', 'loop_cond_data'),
                ('loop_cond_data', 'Switch_1'),

                ('init_1', 'init_1_data'),
                ('init_1_data', 'Enter_1'),
                ('Enter_1', 'Enter_1_data'),
                ('Enter_1_data', 'Merge_1'),
                ('Merge_1', 'Merge_1_data'),
                ('Merge_1_data', 'Less_1'),

                ('Merge_1_data', 'Switch_1'),
                ('Switch_1', 'Switch_1_data'),
                ('Switch_1_data', 'Identity_1'),
                ('Identity_1', 'Identity_1_data'),
                ('Identity_1_data', 'add_1'),
                ('add_1_y', 'add_1_y_data'),
                ('add_1_y_data', 'add_1'),
                ('add_1', 'add_1_data'),
                ('add_1_data', 'NextIteration_1'),
            ],
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """Collapse the matched single-counter condition into one TensorIteratorCondition node.

        Reads the counter's initial value and step constant, creates the
        condition op fed by the trip-count (Strided_slice) data node, then
        removes every matched node that is not in safe_nodes.
        """
        log.debug('================== SimpleConditionFind ===============')
        # init_1
        init_1 = match['init_1_data'].value
        assert init_1 is not None
        init_1 = int(init_1)

        # step_1
        assert match['add_1_y_data'].value is not None
        step_1 = int(match['add_1_y_data'].value)

        # drop stale constant-folded value; this data node is re-driven by the
        # new condition node below
        match['loop_cond_data'].value = None

        # Create condition node and delete all useless nodes from condition pattern
        condition_attrs = dict(iter=dict(init=init_1, step=step_1),
                               name=match['loop_cond'].name + '/TensorIteratorCondition_')
        condition = TensorIteratorCondition(graph, attrs=condition_attrs)
        condition.create_node_with_data(inputs=[match['Strided_slice_data']],
                                        data_nodes=[match['loop_cond_data'], match['Identity_1_data']])

        # Delete useless nodes
        safe_nodes = ['loop_cond_data', 'Identity_1_data', 'Strided_slice', 'Strided_slice_data']
        nodes_for_remove = []
        for node in match.keys():
            if node not in safe_nodes:
                nodes_for_remove.append(match[node].id)
        graph.remove_nodes_from(nodes_for_remove)