2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import copy_node
20 from mo.utils.error import Error
21 from mo.middle.pattern_match import find_isomorphisms
22 from mo.middle.replacement import MiddleReplacementPattern
23 from extensions.ops.lstm_sequence import LSTMSequence
24 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
25 from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
26 from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
27 from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
# Replacement pass (subclass of MiddleReplacementPattern) that matches TensorIterator operations.
# NOTE(review): this listing prefixes every line with the original file's line number, has lost
# its indentation, and skips some original lines (blank lines, the closing docstring quotes,
# method headers). Restore those from the real file before trying to run this code.
30 class TensorIteratorLSTM(MiddleReplacementPattern):
31 """ Detects TensorIterator with LSTMCell of supported form.
33 Collect original operation names of supported LSTMCells in
34 the list LSTMCell.instances_supported_by_IE. It will be used at the second
35 round of the network translation. Mark all supported LSTMCell with flag
36 supported_by_IE to have a chance to detect all not-supported instances
# NOTE(review): the end of the docstring and the `def` line that owns the next statement are
# not visible here. The statement returns four replacer classes imported at the top of the
# file; presumably it is the body of a pass-ordering hook (run after those passes) — confirm.
43 return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, TensorFlowLSTMtoGeneric]
# The only visible pattern entry: a node labelled 'ti' with kind='op' and op='TensorIterator'.
# NOTE(review): the enclosing method and container lines are elided; presumably this is the
# node list of the graph pattern whose matches are passed to replace_pattern — confirm.
48 ('ti', dict(kind='op', op='TensorIterator')),
# Checks the body of a matched TensorIterator against one fixed LSTMCell sub-graph.
#
# graph: the graph being processed (not referenced in the visible lines of this function).
# match: pattern match result.
#
# Raises Error when the body does not match the sub-graph exactly once, or when the matched
# hidden/cell state inputs have a valid 'value' attribute (constants, per the error message).
# Otherwise calls the matched LSTMCell's mark_supported_by_IE attribute if it is set.
#
# NOTE(review): `ti`, `nodes` and `edges` are used below but the lines that bind them are
# elided from this listing; presumably `ti = match['ti']` — confirm against the real file.
55 def replace_pattern(graph: nx.MultiDiGraph, match: dict):
# Visible entries of the expected-body node list: two Reshape ops and one LSTMCell op.
# NOTE(review): further node entries are elided; the edges below also refer to
# input_unsqueezed, input_squeezed, input_hidden, input_cell, weights, biases,
# output_hidden and output_cell.
58 ('squeeze', dict(op='Reshape')),
65 ('lstm', dict(op='LSTMCell')),
69 ('unsqueeze', dict(op='Reshape')),
70 ('output_unsqueezed'),
# Expected edges: input_unsqueezed -> squeeze -> input_squeezed.
73 ('input_unsqueezed', 'squeeze'),
74 ('squeeze', 'input_squeezed'),
# Edges into 'lstm' carry an 'in' attribute 0..4 (presumably the input port index — confirm):
# squeezed input, hidden state, cell state, weights, biases.
76 ('input_squeezed', 'lstm', {'in': 0}),
77 ('input_hidden', 'lstm', {'in': 1}),
78 ('input_cell', 'lstm', {'in': 2}),
79 ('weights', 'lstm', {'in': 3}),
80 ('biases', 'lstm', {'in': 4}),
# Edges out of 'lstm' carry an 'out' attribute: 0 for output_hidden, 1 for output_cell.
82 ('lstm', 'output_hidden', {'out': 0}),
83 ('lstm', 'output_cell', {'out': 1}),
# Only the hidden output continues through the second Reshape.
85 ('output_hidden', 'unsqueeze'),
86 ('unsqueeze', 'output_unsqueezed'),
# Search the TensorIterator body for the sub-graph described by nodes/edges.
89 isomorphisms = find_isomorphisms(ti.body, nodes, edges)
# Exactly one match is required.
# NOTE(review): the result is copied with list() only to take its length, while the
# `isomorphisms[0]` lookup below indexes the original object. That is only correct if
# find_isomorphisms returns an indexable sequence; a one-shot iterator would be consumed here
# and could not be indexed. Confirm, or materialize the result once and reuse it.
90 if len(list(isomorphisms)) != 1:
91 raise Error('Unsupported TensorIterator layer {} was found: either its body, ports or '
92 'edges are not supported by Inference Engine. '
93 'Only TensorIterator with LSTMCell in a body of strict form is supported. '
94 'Please modify the original network '
95 'to meet the requirements.'.format(ti.soft_get('name')))
96 body_match = isomorphisms[0]
# Reject bodies whose initial hidden or cell state node already has a value.
97 if body_match['input_hidden'].has_valid('value') or body_match['input_cell'].has_valid('value'):
98 raise Error('Unsupported TensorIterator layer {} was found: initial hidden and/or cell states '
99 'for LSTMCell are constants. This is not supported. '
100 'Only TensorIterator with LSTMCell in a body of strict form is supported. '
101 'Please modify the original network '
102 'to meet the requirements.'.format(ti.soft_get('name')))
103 # TODO Additional checks for port indices
# If the cell carries a mark_supported_by_IE attribute, call it with the cell node itself.
104 if body_match['lstm'].has_valid('mark_supported_by_IE'):
105 body_match['lstm'].mark_supported_by_IE(body_match['lstm'])
# Replacement pass (subclass of MiddleReplacementPattern) that matches LSTMCell operations.
# NOTE(review): as elsewhere in this listing, the line-number prefixes are residue and some
# original lines (closing docstring quotes, method headers, container brackets) are missing.
108 class CheckUnsupportedLSTMCell(MiddleReplacementPattern):
109 """ Finds all unsupported LSTMCell.
111 Initiates the second translation round if find any not supported LSTMCell instances.
# NOTE(review): the owning `def` line is elided. The statement returns [TensorIteratorLSTM];
# presumably it is the pass-ordering hook, so this pass runs after TensorIteratorLSTM — confirm.
117 return [TensorIteratorLSTM]
# The only visible pattern entry: a node labelled 'lstm' with op='LSTMCell'; replace_pattern
# reads it back as match['lstm'].
122 ('lstm', dict(op='LSTMCell')),
# NOTE(review): this function takes no `self`; presumably it is declared with @staticmethod
# inside CheckUnsupportedLSTMCell on a line that is elided from this listing — confirm and
# re-indent accordingly when restoring the file.
def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """Request a second conversion round when a matched LSTMCell is not supported.

    Args:
        graph: graph being processed; ``graph.graph['repeat_conversion']`` is set to
            ``True`` when the whole conversion has to be repeated.
        match: pattern match result; ``match['lstm']`` is the matched LSTMCell node.

    A cell without a valid ``finalize_first_round`` attribute was not created with the
    pattern (for example in ONNX), so it never initiates the second round.
    """
    lstmcell = match['lstm']

    if not lstmcell.has_valid('finalize_first_round'):
        # in case when there is no lstmcell.finalize_first_round then this cell wasn't created
        # with the pattern (for example in ONNX) and we don't initiate the second round.
        return

    lstmcell.finalize_first_round()
    if not lstmcell.has_and_set('supported_by_IE'):
        # this is a signal for the main translation pipeline to repeat the entire conversion process
        graph.graph['repeat_conversion'] = True