Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / lstm_sequence_normalize.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import networkx as nx
18 import numpy as np
19 from copy import deepcopy
20
21 from extensions.middle.decompose_bi_lstm import DecomposeBiLSTM
22 from mo.utils.error import Error
23 from mo.middle.replacement import MiddleReplacementPattern
24 from mo.ops.op import Op
25 from mo.ops.permute import Permute
26 from mo.ops.reshape import Reshape
27 from mo.graph.graph import Node
28
29
30 def inverse_perm(order: np.array):
31     indices = np.empty(order.size, dtype=np.int64)
32     indices[order] = np.arange(order.size)
33     return indices
34
35
36 def permute_before_and_after(inp: Node, middle: Node, out: Node, order):
37     ''' Insert two permutes: before middle node and after middle node.
38
39         The first permute has a given order, the second permute has an
40         inversed order.
41     '''
42
43     permute = Permute(middle.graph, dict(order=np.array(order)))
44
45     edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
46     middle.graph.remove_edge(inp.id, middle.id)
47     new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
48     middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)
49
50     permute = Permute(middle.graph, dict(order=inverse_perm(np.array(order))))
51
52     middle.graph.remove_edge(middle.id, out.id)
53     new_out = Op._create_data_node(middle.graph, name=middle.name + '/WithoutPermute', attrs={'shape': out.shape[order]})
54     middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
55     permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
56
57
58 class LSTMSequenceNormalize(MiddleReplacementPattern):
59     ''' Convert blobs and shapes of ONNX-like LSTM to IE compatible form.
60
61         Fuse W, R and optional B input blobs to weights and biases according
62         to IE LSTM specification. In case of bidirectional LSTM, the resulting
63         blobs are not directly supported by IE, but it will be further processed
64         by a separate transformation to break down to one-directional LSTMs.
65
66         The target form of this operation is not normally covered by a dedicated
67         layer in IE. It should be further transformed to some other layer
68         that are supported by IE. This transformation pass involves weights and
69         shapes processing only.
70
71         Post-conditions:
72
73         Inputs have the following order:
74             0: input data
75             1: weights blob
76             2: biases blob
77             3: initial hidden state [optional]
78             4: initial cell state [optional]
79     '''
80
81     enabled = True
82
83
84     def run_after(self):
85         return [
86             DecomposeBiLSTM
87         ]
88
89
90     def pattern(self):
91         return dict(
92             nodes=[
93                 ('lstm', dict(kind='op', op='LSTMSequence', format='onnx')),
94                 ('input', dict(kind='data')),
95                 ('W', dict(kind='data')),
96                 ('R', dict(kind='data')),
97             ],
98             edges=[
99                 ('input', 'lstm', {'in': 0}),
100                 ('W', 'lstm', {'bin': 'W'}),
101                 ('R', 'lstm', {'bin': 'R'}),
102             ]
103         )
104
105
106     def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
107         self.repack_weights(graph, match)
108         if match['lstm'].has_num_directions:
109             self.squeeze_num_directions(graph, match)
110         self.batch_sequence_transpose(graph, match)
111         self.check_not_supported_ports(graph, match)
112         self.states_squeeze(graph, match)
113
114
115     def repack_weights(self, graph: nx.MultiDiGraph, match: dict):
116
117         lstm = match['lstm']
118         W = match['W'].value.copy()
119         R = match['R'].value.copy()
120
121         # bidirectional case should be processed separately before this transformation
122         if lstm.direction not in ['forward', 'reverse']:
123             raise Error('ONNX/LSTM operator with `forward` or `reverse` is supported only. '
124                 'Node {} has direction = {} which is not supported.'.format(lstm.name, lstm.direction))
125
126         graph.remove_edge(match['W'].id, lstm.id)
127         graph.remove_edge(match['R'].id, lstm.id)
128
129         # find optional 'B'
130         if 3 in lstm.in_nodes():
131             # TODO: check if 'bin': 'B' attribute is assigned to this edge
132             B = lstm.in_node(3).value.copy()
133             graph.remove_edge(lstm.in_node(3).id, lstm.id)
134         else:
135             B = np.full([1, lstm.hidden_size*8], 0, dtype=np.float32)
136
137         # Add extra dimensions for W, R and B for easier repacking
138
139         B = B.reshape([
140             1,  # 0: num of directions, limitation: should be 1
141             2,  # 1: two input parts of the matrix: W, R
142             4,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
143             lstm.hidden_size,  # 3: output size per direction and gate
144             1,  # 4: fake dimension to match the input dimension in W and R for shorter code
145         ])
146
147         W, R = [x.reshape([
148                 1,  # 0: num of directions, limitation: should be 1
149                 1,  # 1: dummy dimension to be aligned with B
150                 4,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
151                 lstm.hidden_size,  # 3: output size per direction and gate
152                 -1])  # 4: input size/hidden size in W/R
153             for x in (W, R)]
154
155         input_size = match['input'].shape[2]
156         assert input_size == W.shape[-1]
157
158         WR = np.concatenate([W, R], axis=4)
159
160         # Reorder gates: iofc --> fico
161         gate_reorder = [2, 0, 3, 1]
162         WR = np.take(WR, gate_reorder, axis=2)
163         B = np.take(B, gate_reorder, axis=2)
164
165         # Sum component of B that correspond to W and R
166         B = np.add.reduce(B, axis=1, keepdims=True)
167
168         # Reorder dimensions by collection output dimensions first, then input dimension
169         # Interpret the numbers below by looking at W, R and B reshape above in the code
170         inout_reorder = [0, 2, 3, 1, 4]
171         WR = WR.transpose(inout_reorder)
172         B = B.transpose(inout_reorder)
173
174         # Supposing it is unidirectional LSTM, squeeze 'direction' dimension
175         assert WR.shape[0] == 1
176         assert B.shape[0] == 1
177         WR = WR.squeeze(axis=0)
178         B = B.squeeze(axis=0)
179
180         # Flatten all output (0, 1) and input dimensions (2, 3)
181         final_shape = [WR.shape[0] * WR.shape[1], -1]
182         WR = WR.reshape(final_shape)
183         B = B.reshape(final_shape)
184
185         # Squeeze fake dimension in B
186         B = B.squeeze(axis=-1)
187
188         assert WR.ndim == 2
189         assert B.ndim == 1
190         assert WR.shape[0] == lstm.hidden_size*4
191         assert B.shape[0] == lstm.hidden_size*4
192         assert WR.shape[1] == lstm.hidden_size + input_size
193
194         for blob, port, name in [(WR, 1, 'weights'), (B, 2, 'biases')]:
195             Op.create_and_connect_input_data_node(
196                 graph,
197                 lstm,
198                 {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
199                 {'in': port, 'bin': name, 'permutation': None}
200             )
201
202
203     def squeeze_num_directions(self, graph: nx.MultiDiGraph, match: dict):
204         """ Assuming considered LSTM node has num_directions in output shape, remove it. """
205         lstm = match['lstm']
206         # num_directions is at 1st position in output shape, please refer to LSTMSequence op definition
207
208         direction_dim = [1, 0, 0] # index of dimension with direction index
209         for i in lstm.out_nodes():
210             old_data_node = lstm.out_node(i)
211             old_shape = old_data_node.shape.copy()
212             new_shape = np.delete(old_shape, direction_dim[i])
213             data = Op._create_data_node(graph, name=lstm.name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
214             graph.remove_edge(lstm.id, old_data_node.id)
215             graph.add_edge(lstm.id, data.id, key=0, out=i)
216             reshape = Reshape(graph, dict(dim=old_shape))
217             reshape.create_node_with_data([data], dict(name=lstm.name + '/SqueezeNumDirections/{}'.format(i)), data_nodes=[old_data_node])
218
219
220     def batch_sequence_transpose(self, graph: nx.MultiDiGraph, match: dict):
221
222         lstm = match['lstm']
223         inp = match['input']
224         out = lstm.out_node(0)
225
226         if lstm.batch_dim == 0:
227             assert lstm.sequence_dim == 1
228             # nothing to do -- it's already in normal form
229             return
230
231         assert lstm.sequence_dim == 0
232         assert lstm.batch_dim == 1
233         assert len(inp.shape) == 3
234
235         # Reorder the first two dimensions on both ends: input and output.
236         # Two Permute ops are inserted before and after the LSTM node.
237         # In this transformation we don't analyze the rest of the model around
238         # LSTM cell, so these Permute ops are not fused to some other layers here.
239         # But other transformations in the pipeline may optimize the Permute ops out.
240
241         lstm.batch_dim, lstm.sequence_dim = lstm.sequence_dim, lstm.batch_dim
242         permute_before_and_after(inp, lstm, out, [1, 0, 2])
243
244
245     def check_not_supported_ports(self, graph: nx.MultiDiGraph, match: dict):
246         lstm = match['lstm']
247         inputs = lstm.in_edges()
248         assert 0 in inputs
249         assert 1 in inputs and inputs[1]['bin'] == 'weights'
250         assert 2 in inputs and inputs[2]['bin'] == 'biases'
251         assert 3 not in inputs
252         
253         if not(set(list(inputs.keys())) <= set([0, 1, 2, 5, 6])):
254             raise Error('Node {} that is interpreted as {} operation has '
255                 'some unexpected inputs initialized, '
256                 'they can include: sequence_lenght, '
257                 'and weight tensor for peepholes. '
258                 'This is not supported.'.format(lstm.name, lstm.op))
259
260
261     def states_squeeze(self, graph: nx.MultiDiGraph, match: dict):
262
263         lstm = match['lstm']
264
265         reshape = Reshape(graph, dict(dim=[lstm.in_node(0).shape[0], lstm.hidden_size]))
266
267         if len(lstm.in_nodes()) > 3:
268             init_h = lstm.in_node(5)
269             edge_attrs = deepcopy(graph.get_edge_data(init_h.id, lstm.id)[0])
270             edge_attrs['in'] = 3
271             graph.remove_edge(init_h.id, lstm.id)
272             new_init_h = reshape.create_node_with_data([init_h], dict(name=lstm.name + '/HiddenStateResize'))
273             graph.add_edge(new_init_h.id, lstm.id, **edge_attrs)
274
275         if len(lstm.in_nodes()) > 4:
276             init_c = lstm.in_node(6)
277             edge_attrs = deepcopy(graph.get_edge_data(init_c.id, lstm.id)[0])
278             edge_attrs['in'] = 4
279             graph.remove_edge(init_c.id, lstm.id)
280             new_init_c = reshape.create_node_with_data([init_c], dict(name=lstm.name + '/CellStateResize'))
281             graph.add_edge(new_init_c.id, lstm.id, **edge_attrs)