Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / lstm_tensor_iterator_to_lstm_sequence.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import networkx as nx
18
19 from mo.graph.graph import copy_node
20 from mo.utils.error import Error
21 from mo.middle.pattern_match import find_isomorphisms
22 from mo.middle.replacement import MiddleReplacementPattern
23 from extensions.ops.lstm_sequence import LSTMSequence
24 from extensions.middle.TensorIteratorMerge import TensorIteratorMerge
25 from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize, permute_before_and_after
26 from extensions.middle.lstm_sequence_tensor_iterator import LSTMSequenceTensorIterator
27 from extensions.middle.TF_lstm_cell_to_generic import TensorFlowLSTMtoGeneric
28
29
30 class TensorIteratorLSTM(MiddleReplacementPattern):
31     """ Detects TensorIterator with LSTMCell of supported form.
32
33         Collect original operation names of supported LSTMCells in
34         the list LSTMCell.instances_supported_by_IE. It will be used at the second
35         round of the network translation. Mark all supported LSTMCell with flag
36         supported_by_IE to have a chance to detect all not-supported instances
37         in a separate pass.
38     """
39
40     enabled = False
41
42     def run_after(self):
43         return [TensorIteratorMerge, LSTMSequenceNormalize, LSTMSequenceTensorIterator, TensorFlowLSTMtoGeneric]
44
45     def pattern(self):
46         return dict(
47             nodes=[
48                 ('ti', dict(kind='op', op='TensorIterator')),
49             ],
50             edges=[
51             ]
52         )
53
54     @staticmethod
55     def replace_pattern(graph: nx.MultiDiGraph, match: dict):
56         nodes=[
57             ('input_unsqueezed'),
58             ('squeeze', dict(op='Reshape')),
59             ('input_squeezed'),
60             ('input_hidden'),
61             ('input_cell'),
62             ('weights'),
63             ('biases'),
64
65             ('lstm', dict(op='LSTMCell')),
66
67             ('output_hidden'),
68             ('output_cell'),
69             ('unsqueeze', dict(op='Reshape')),
70             ('output_unsqueezed'),
71         ]
72         edges=[
73             ('input_unsqueezed', 'squeeze'),
74             ('squeeze', 'input_squeezed'),
75
76             ('input_squeezed', 'lstm', {'in': 0}),
77             ('input_hidden', 'lstm', {'in': 1}),
78             ('input_cell', 'lstm', {'in': 2}),
79             ('weights', 'lstm', {'in': 3}),
80             ('biases', 'lstm', {'in': 4}),
81
82             ('lstm', 'output_hidden', {'out': 0}),
83             ('lstm', 'output_cell', {'out': 1}),
84
85             ('output_hidden', 'unsqueeze'),
86             ('unsqueeze', 'output_unsqueezed'),
87         ]
88         ti = match['ti']
89         isomorphisms = find_isomorphisms(ti.body, nodes, edges)
90         if len(list(isomorphisms)) != 1:
91             raise Error('Unsupported TensorIterator layer {} was found: either its body, ports or '
92                         'edges are not supported by Inference Engine. '
93                         'Only TensorIterator with LSTMCell in a body of strict form is supported. '
94                         'Please modify the original network '
95                         'to meet the requirements.'.format(ti.soft_get('name')))
96         body_match = isomorphisms[0]
97         if body_match['input_hidden'].has_valid('value') or body_match['input_cell'].has_valid('value'):
98             raise Error('Unsupported TensorIterator layer {} was found: initial hidden and/or cell states '
99                         'for LSTMCell are constants. This is not supported. '
100                         'Only TensorIterator with LSTMCell in a body of strict form is supported. '
101                         'Please modify the original network '
102                         'to meet the requirements.'.format(ti.soft_get('name')))
103         # TODO Additional checks for port indices
104         if body_match['lstm'].has_valid('mark_supported_by_IE'):
105             body_match['lstm'].mark_supported_by_IE(body_match['lstm'])
106
107
108 class CheckUnsupportedLSTMCell(MiddleReplacementPattern):
109     """ Finds all unsupported LSTMCell.
110
111         Initiates the second translation round if find any not supported LSTMCell instances.
112     """
113
114     enabled = False
115
116     def run_after(self):
117         return [TensorIteratorLSTM]
118
119     def pattern(self):
120         return dict(
121             nodes=[
122                 ('lstm', dict(op='LSTMCell')),
123             ],
124             edges=[
125             ]
126         )
127
128     @staticmethod
129     def replace_pattern(graph: nx.MultiDiGraph, match: dict):
130         lstmcell = match['lstm']
131         if lstmcell.has_valid('finalize_first_round'):
132             lstmcell.finalize_first_round()
133             if not lstmcell.has_and_set('supported_by_IE'):
134                 # this is a signal for the main translation pipeline to repeat the entire conversion process
135                 graph.graph['repeat_conversion'] = True
136         # in case when there is no lstmcell.finalize_first_round then this cell wasn't created with the pattern
137         # (for example in ONNX) and we don't initiate the second round.