Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TensorIteratorLSTMToLSTMSequence.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
18 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
19 from mo.graph.graph import Graph
20 from mo.middle.pattern_match import find_isomorphisms
21 from mo.middle.replacement import MiddleReplacementPattern
22 from mo.utils.error import Error
23 from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize, permute_before_and_after
24
25
26 class TensorIteratorLSTM(MiddleReplacementPattern):
27     """ Detects TensorIterator with LSTMCell of supported form.
28
29         Collect original operation names of supported LSTMCells in
30         the list LSTMCell.instances_supported_by_IE. It will be used at the second
31         round of the network translation. Mark all supported LSTMCell with flag
32         supported_by_IE to have a chance to detect all not-supported instances
33         in a separate pass.
34     """
35
36     enabled = False
37
38     def run_after(self):
39         return [TensorIteratorMerge, ONNXRNNSequenceNormalize, TensorFlowLSTMtoGeneric]
40
41     def pattern(self):
42         return dict(
43             nodes=[
44                 ('ti', dict(kind='op', op='TensorIterator')),
45             ],
46             edges=[
47             ]
48         )
49
50     @staticmethod
51     def replace_pattern(graph: Graph, match: dict):
52         nodes = [
53             ('input_unsqueezed'),
54             ('squeeze', dict(op='Reshape')),
55             ('input_squeezed'),
56             ('input_hidden'),
57             ('input_cell'),
58             ('weights'),
59             ('biases'),
60
61             ('lstm', dict(op='LSTMCell')),
62
63             ('output_hidden'),
64             ('output_cell'),
65             ('unsqueeze', dict(op='Reshape')),
66             ('output_unsqueezed'),
67         ]
68         edges = [
69             ('input_unsqueezed', 'squeeze'),
70             ('squeeze', 'input_squeezed'),
71
72             ('input_squeezed', 'lstm', {'in': 0}),
73             ('input_hidden', 'lstm', {'in': 1}),
74             ('input_cell', 'lstm', {'in': 2}),
75             ('weights', 'lstm', {'in': 3}),
76             ('biases', 'lstm', {'in': 4}),
77
78             ('lstm', 'output_hidden', {'out': 0}),
79             ('lstm', 'output_cell', {'out': 1}),
80
81             ('output_hidden', 'unsqueeze'),
82             ('unsqueeze', 'output_unsqueezed'),
83         ]
84         ti = match['ti']
85         isomorphisms = find_isomorphisms(ti.body, nodes, edges)
86         if len(list(isomorphisms)) != 1:
87             raise Error('Unsupported TensorIterator layer {} was found: either its body, ports or '
88                         'edges are not supported by Inference Engine. '
89                         'Only TensorIterator with LSTMCell in a body of strict form is supported. '
90                         'Please modify the original network '
91                         'to meet the requirements.'.format(ti.soft_get('name')))
92         body_match = isomorphisms[0]
93         if body_match['input_hidden'].has_valid('value') or body_match['input_cell'].has_valid('value'):
94             raise Error('Unsupported TensorIterator layer {} was found: initial hidden and/or cell states '
95                         'for LSTMCell are constants. This is not supported. '
96                         'Only TensorIterator with LSTMCell in a body of strict form is supported. '
97                         'Please modify the original network '
98                         'to meet the requirements.'.format(ti.soft_get('name')))
99         # TODO Additional checks for port indices