"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

from extensions.ops.TensorIterator_ops import TensorIteratorOutput
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern


class SmartOutputMatcher(MiddleReplacementPattern):
    """
    This pattern matches partitioned outputs for TensorIterator in dynamic_rnn loops in TF.

    Inside the loop body a value is written into a TensorArray (TensorArrayWrite); after the loop the
    whole array is read back by TensorArrayGather with indices Range(0, TensorArraySize, 1).

    Matched chains (data nodes omitted, every node is named after its op, [N] is the input port):

        TensorArray --flow--> Enter -> Merge -> Switch -> Exit -> TensorArraySize -> Range[1]
        Switch -> Identity -> TensorArrayWrite[3] -> NextIteration
        Exit -> TensorArrayGather
        Range -> TensorArrayGather
        Const(start) -> Range[0], Const(delta) -> Range[2]
        TensorArray --handle--> Enter (WriteEnter) -> TensorArrayWrite[0]
        TensorArray --handle--> TensorArraySize, TensorArrayGather
        Condition -> Switch

    The matched sub-graph is replaced by one TensorIteratorOutput node that produces the original
    TensorArrayGather output data node.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        from extensions.middle.TensorIteratorInput import SmartInputMatcher
        return [SmartInputMatcher]

    def run_before(self):
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        return [TensorIteratorMerge]

    def pattern(self):
        return dict(
            nodes=[
                ('TensorArray', dict(kind='op', op='TensorArrayV3')),
                ('TensorArray_data', dict(kind='data')),
                ('TensorArray_flow_data', dict(kind='data')),
                ('TensorArrayGather', dict(kind='op', op='TensorArrayGatherV3')),
                ('TensorArrayGather_data', dict(kind='data')),
                ('range', dict(kind='op', op='Range')),
                ('range_data', dict(kind='data')),
                ('size', dict(kind='op', op='TensorArraySizeV3')),
                ('size_data', dict(kind='data')),
                ('start', dict(kind='op', op='Const')),
                ('start_data', dict(kind='data')),
                ('delta', dict(kind='op', op='Const')),
                ('delta_data', dict(kind='data')),
                ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
                ('TensorArrayWrite_data', dict(kind='data')),
                ('NextIteration', dict(kind='op', op='NextIteration')),
                ('Condition_data', dict(kind='data')),
                ('Identity_2_data', dict(kind='data')),
                ('Identity_2', dict(kind='op', op='Identity')),
                ('Switch_2', dict(kind='op', op='Switch')),
                ('Switch_2_data', dict(kind='data')),
                ('Switch_2_data_exit', dict(kind='data')),
                ('Merge_2', dict(kind='op', op='Merge')),
                ('Merge_2_data', dict(kind='data')),
                ('Enter_2', dict(kind='op', op='Enter')),
                ('Enter_2_data', dict(kind='data')),
                ('WriteEnter', dict(kind='op', op='Enter')),
                ('WriteEnter_data', dict(kind='data')),
                ('Exit', dict(kind='op', op='Exit')),
                ('Exit_data', dict(kind='data')),
            ],
            edges=[
                ('TensorArray', 'TensorArray_data'),
                ('TensorArray', 'TensorArray_flow_data'),
                ('TensorArray_flow_data', 'Enter_2'),
                ('TensorArray_data', 'WriteEnter'),
                ('TensorArray_data', 'TensorArrayGather'),
                ('TensorArrayGather', 'TensorArrayGather_data'),
                ('TensorArray_data', 'size'),

                ('size', 'size_data'),
                ('start', 'start_data'),
                ('delta', 'delta_data'),

                # Range(start, limit, delta): the gather indices span the whole array
                ('size_data', 'range', {'in': 1}),
                ('start_data', 'range', {'in': 0}),
                ('delta_data', 'range', {'in': 2}),
                ('range', 'range_data'),
                ('range_data', 'TensorArrayGather'),

                ('Enter_2', 'Enter_2_data'),
                ('Enter_2_data', 'Merge_2'),
                ('Merge_2', 'Merge_2_data'),
                ('Merge_2_data', 'Switch_2'),
                ('Switch_2', 'Switch_2_data'),
                ('Switch_2', 'Switch_2_data_exit'),
                ('Switch_2_data', 'Identity_2'),
                ('Identity_2', 'Identity_2_data'),

                ('Switch_2_data_exit', 'Exit'),
                ('Exit', 'Exit_data'),
                ('Exit_data', 'size'),
                ('Exit_data', 'TensorArrayGather'),

                # TensorArrayWriteV3 port 0 is the array handle
                ('WriteEnter', 'WriteEnter_data'),
                ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),

                # TensorArrayWriteV3 port 3 is the flow input
                ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),

                ('TensorArrayWrite', 'TensorArrayWrite_data'),
                ('TensorArrayWrite_data', 'NextIteration'),
                ('Condition_data', 'Switch_2'),
            ]
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched TensorArray write/gather sub-graph with a single TensorIteratorOutput node.

        :param graph: the graph being transformed; modified in place
        :param match: pattern node names (see pattern()) mapped to the matched graph nodes
        """
        log.debug('================== SmartOutputFind ===============')

        # The value of WriteEnter_data becomes the external port id of the new output.
        assert match['WriteEnter_data'].value is not None, \
            'WriteEnter_data must have a value to be used as the TensorIteratorOutput external port id'
        # Only a gather over Range(0, size, 1), i.e. over the whole array in order, is supported.
        assert match['start_data']['value'] == 0 and match['delta_data']['value'] == 1, \
            'TensorArrayGather indices must be produced by Range(0, size, 1)'

        # in_node() is input port 0; for TensorArrayV3 that is the array size.
        ta_size = match['TensorArray'].in_node()

        # TensorArrayWriteV3 inputs: 0 - handle, 1 - index, 2 - value, 3 - flow.
        index = match['TensorArrayWrite'].in_node(1)
        value = match['TensorArrayWrite'].in_node(2)

        # axis == 0 because a TensorArray is always iterated over axis 0; start, stride and part_size
        # are left as None here and are expected to be filled in later by another pass.
        output = TensorIteratorOutput(graph, dict(axis=0, start=None, stride=None, part_size=None,
                                                  external_port_id=str(match['WriteEnter_data'].value),
                                                  internal_layer_id=value.id,
                                                  name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'
                                                  ))
        output.create_node_with_data(inputs=[ta_size, value, index],
                                     data_nodes=[match['TensorArrayGather_data']])

        # Delete the rest of the matched sub-graph. TensorArrayGather_data is kept because it is now
        # produced by the TensorIteratorOutput node; Condition_data is kept because it is the loop
        # condition feeding Switch (presumably still consumed by other Switch nodes of the loop).
        safe_nodes = {'TensorArrayGather_data', 'Condition_data'}
        graph.remove_nodes_from([match[name].id for name in match if name not in safe_nodes])


class SimpleOutputMatcher(MiddleReplacementPattern):
    """
    This pattern matches outputs for TensorIterator in dynamic_rnn loops in TF whose TensorArray is read
    back after the loop by a single TensorArrayRead (the TensorArrayGather case is handled by
    SmartOutputMatcher).

    Matched chains (data nodes omitted, every node is named after its op, [N] is the input port):

        TensorArray --flow--> Enter -> Merge -> Switch -> Exit -> TensorArrayRead
        Switch -> Identity -> TensorArrayWrite[3] -> NextIteration -> Merge
        TensorArray --handle--> Enter (WriteEnter) -> TensorArrayWrite[0]
        TensorArray --handle--> TensorArrayRead
        Condition -> Switch

    The matched sub-graph is replaced by one TensorIteratorOutput node that produces the original
    TensorArrayRead output data node. Unlike SmartOutputMatcher, no axis/start/stride/part_size is set
    on the created node.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['is_cyclic']]

    def run_after(self):
        return [SmartOutputMatcher]

    def run_before(self):
        from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
        from extensions.middle.TensorIteratorCondition import LoopConditionMatcher
        return [TensorIteratorMerge, LoopConditionMatcher]

    def pattern(self):
        return dict(
            nodes=[
                ('TensorArray', dict(kind='op', op='TensorArrayV3')),
                ('TensorArray_data', dict(kind='data')),
                ('TensorArray_flow_data', dict(kind='data')),

                ('TensorArrayWrite', dict(kind='op', op='TensorArrayWriteV3')),
                ('TensorArrayWrite_data', dict(kind='data')),

                ('NextIteration', dict(kind='op', op='NextIteration')),
                ('NextIteration_data', dict(kind='data')),

                ('Condition_data', dict(kind='data')),

                ('Identity_2', dict(kind='op', op='Identity')),
                ('Identity_2_data', dict(kind='data')),

                ('Switch_2', dict(kind='op', op='Switch')),
                ('Switch_2_data', dict(kind='data')),
                ('Switch_2_data_exit', dict(kind='data')),

                ('Merge_2', dict(kind='op', op='Merge')),
                ('Merge_2_data', dict(kind='data')),

                ('Enter_2', dict(kind='op', op='Enter')),
                ('Enter_2_data', dict(kind='data')),

                ('WriteEnter', dict(kind='op', op='Enter')),
                ('WriteEnter_data', dict(kind='data')),

                ('Exit', dict(kind='op', op='Exit')),
                ('Exit_data', dict(kind='data')),

                # kind='op' for consistency with every other op node of the pattern
                ('TensorArrayRead', dict(kind='op', op='TensorArrayReadV3')),
                ('TensorArrayRead_data', dict(kind='data')),
            ],
            edges=[
                ('TensorArray', 'TensorArray_data'),
                ('TensorArray', 'TensorArray_flow_data'),
                ('TensorArray_flow_data', 'Enter_2'),
                ('TensorArray_data', 'WriteEnter'),

                ('Enter_2', 'Enter_2_data'),
                ('Enter_2_data', 'Merge_2'),
                ('Merge_2', 'Merge_2_data'),
                ('Merge_2_data', 'Switch_2'),
                ('Switch_2', 'Switch_2_data'),
                ('Switch_2', 'Switch_2_data_exit'),
                ('Switch_2_data', 'Identity_2'),
                ('Identity_2', 'Identity_2_data'),

                ('Switch_2_data_exit', 'Exit'),
                ('Exit', 'Exit_data'),
                ('Exit_data', 'TensorArrayRead'),

                # TensorArrayWriteV3 port 0 is the array handle
                ('WriteEnter', 'WriteEnter_data'),
                ('WriteEnter_data', 'TensorArrayWrite', {'in': 0}),

                # TensorArrayWriteV3 port 3 is the flow input
                ('Identity_2_data', 'TensorArrayWrite', {'in': 3}),

                ('TensorArrayWrite', 'TensorArrayWrite_data'),
                ('TensorArrayWrite_data', 'NextIteration'),
                ('Condition_data', 'Switch_2'),

                ('TensorArray_data', 'TensorArrayRead'),
                ('TensorArrayRead', 'TensorArrayRead_data'),
                # Back edge of the loop: NextIteration feeds Merge
                ('NextIteration', 'NextIteration_data'),
                ('NextIteration_data', 'Merge_2'),
            ]
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Replace the matched TensorArray write/read sub-graph with a single TensorIteratorOutput node.

        :param graph: the graph being transformed; modified in place
        :param match: pattern node names (see pattern()) mapped to the matched graph nodes
        """
        log.debug('================== SimpleOutputFind ===============')

        # The value of WriteEnter_data becomes the external port id of the new output.
        assert match['WriteEnter_data'].value is not None, \
            'WriteEnter_data must have a value to be used as the TensorIteratorOutput external port id'

        # TensorArrayWriteV3 inputs: 0 - handle, 1 - index, 2 - value, 3 - flow.
        index = match['TensorArrayWrite'].in_node(1)
        value = match['TensorArrayWrite'].in_node(2)

        # Only the port/layer ids and the name are set: unlike SmartOutputMatcher (which passes axis=0 and
        # placeholder start/stride/part_size), this output is created without any partitioning attributes.
        output = TensorIteratorOutput(graph, dict(
            external_port_id=str(match['WriteEnter_data'].value),
            internal_layer_id=value.id,
            name=match['TensorArrayWrite'].name + '/TensorIteratorOutput_'
        ))
        output.create_node_with_data(inputs=[value, index],
                                     data_nodes=[match['TensorArrayRead_data']])

        # Delete the rest of the matched sub-graph. TensorArrayRead_data is kept because it is now
        # produced by the TensorIteratorOutput node; Condition_data is kept because it is the loop
        # condition feeding Switch (presumably still consumed by other Switch nodes of the loop).
        safe_nodes = {'TensorArrayRead_data', 'Condition_data'}
        graph.remove_nodes_from([match[name].id for name in match if name not in safe_nodes])