Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / MXNetRNNSequenceNormalize.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from mo.graph.graph import Graph
20 from mo.middle.replacement import MiddleReplacementPattern
21 from mo.ops.op import Op
22 from mo.ops.permute import Permute
23 from mo.ops.reshape import Reshape
24
25
26 class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
27     """
28         Convert blobs and shapes of MXNet-like RNN cell to IE compatible form.
29
30         The target form of this operation is not normally covered by a dedicated
31         layer in IE. It should be further transformed to some other layer
32         that are supported by IE. This transformation pass involves weights and
33         shapes processing only.
34
35         Post-conditions:
36         Inputs:
37             0: X input data,    shape [batch_size, seq_len, input_size] (or [seq_len. bathc_size, int_size], depends on
38                                 batch_dim param)
39             1: W weights blob,  shape [num_dir, n_cells, M, hidden_size, input_size]
40             2: R weights blob,  shape [num_dir, n_cells, M, hidden_size, hidden_size]
41             3: B biases blob,   shape [num_dir, n_cells, 2, M, hidden_size]
42             4: (optional) sequence_length, shape [batch_size]
43             5: initial hidden state, shape  [num_dir, batch_size, hidden_size]
44                                                       ([num_dir, n_cells, batch_size, hidden_size] if num_cells != 1)
45             6: (only for LSTM) initial cell state, shape [num_dir, batch_size, hidden_size]
46             7: (optional for LSTM) Peepholes weights, shape  [num_dir, n_cells, (M - 1) * hidden_size]
47
48         Outputs:
49             0: Y output blob,  shape [batch_size, num_dir, seq_len, hidden_size]
50             1: (optional) Y_h, shape [num_dir, batch_size, hidden_size]
51             2: (optional for LSTM) Y_c, shape [num_dir, batch_size, hidden_size]
52
53         Where:
54             M -- number of gates in this cell (4 for LSTM, 3 for GRU, 1 for RNN).
55             num_dir -- number of directions ('forvard', 'bidirectional', 'reverse')
56             n_cells -- number of cells in layer (always 1 for ONNX).
57
58     """
59     enabled = True
60
61     def run_after(self):
62         from extensions.middle.MXNetSplitMultiLayers import MXNetSplitLayersToRNNSequence
63         return [MXNetSplitLayersToRNNSequence]
64
65     def pattern(self):
66         return dict(
67             nodes=[
68                 ('rnn_layer', dict(kind='op', type='RNNSequence', format='mxnet')),
69                 ('input', dict(kind='data')),
70                 ('params', dict(kind='data')),
71             ],
72             edges=[
73                 ('input', 'rnn_layer', {'in': 0}),
74                 ('params', 'rnn_layer', {'in': 1}),
75             ]
76         )
77
78     def replace_pattern(self, graph: Graph, match: dict):
79         rnn_layer = match['rnn_layer']
80
81         self.check_init_states(graph, match)
82         self.repack_weights(graph, match)
83         self.add_output_reshape(graph, match)
84         self.check_input_ports(graph, match)
85         rnn_layer['normalized'] = True
86
87     @staticmethod
88     def repack_weights(graph: Graph, match: dict):
89         input = match['input']
90         rnn_layer = match['rnn_layer']
91         params = match['params'].value.copy()
92
93         graph.remove_edge(match['params'].id, rnn_layer.id)
94
95         input_size = input.shape[2]
96         direction = 2 if rnn_layer.has_num_directions else 1
97         bsize = (2 * rnn_layer.hidden_size * direction * 1) * rnn_layer.multiplier
98
99         W = np.array(params[0:len(params) - bsize])
100         B = np.array(params[len(params) - bsize:])
101
102         W = W.reshape((direction, -1))
103         B = B.reshape((direction, -1))
104
105         W, R = np.array(W[:, 0:rnn_layer.hidden_size * rnn_layer.multiplier * input_size]), np.array(W[:, rnn_layer.hidden_size * rnn_layer.multiplier* input_size:])
106
107         W, R = [x.reshape([
108             direction,  # 0: num of directions
109             1,  # 1: num_cells
110             rnn_layer.multiplier,  # 2: four output parts of the matrix for all gates
111             rnn_layer.hidden_size,  # 3: output size per direction and gate
112             -1])  # 4: input size/hidden size in W/R correspondingly
113             for x in (W, R)]
114
115         assert W.shape[-1] == input_size
116         assert R.shape[-1] == rnn_layer.hidden_size
117
118         B = B.reshape([
119                  direction,  # 0: num of directions, limitation: should be 1
120                  1,
121                  2,  # 3: num of component B
122                  rnn_layer.multiplier,  # 1: four output parts of the matrix for all gates in order: i, f, c, o
123                  rnn_layer.hidden_size,  # 2: output size per direction and gate
124         ])
125
126         # Reorder gates: ifco --> fico
127         gate_reorder = rnn_layer.gate_order
128         W = np.take(W, gate_reorder, axis=2)
129         R = np.take(R, gate_reorder, axis=2)
130         B = np.take(B, gate_reorder, axis=3)
131
132         for blob, port in [(W, 1), (R, 2), (B, 3)]:
133             Op.create_and_connect_input_data_node(
134                 graph,
135                 rnn_layer,
136                 {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
137                 {'in': port, 'permutation': None}
138             )
139
140     @staticmethod
141     def check_init_states(graph: Graph, match: dict):
142         """
143         Check if cell have initial states and create zeros states if not.
144         And renumber ports for this states.
145         """
146         rnn_cell = match['rnn_layer']
147         num_directions = 2 if rnn_cell.direction == 'bidirectional' else 1
148         batch_size = rnn_cell.in_node(0).shape[rnn_cell.batch_dim]
149
150         h_init_port = 5
151         c_init_port = 6
152
153         if 2 not in rnn_cell.in_nodes():
154             h_shape = [num_directions, batch_size, rnn_cell.hidden_size]  # from ONNX spec
155             h_init = np.full(h_shape, 0, dtype=np.float32)
156             Op.create_and_connect_input_data_node(
157                 graph,
158                 rnn_cell,
159                 {'value': h_init, 'shape': np.array(h_init.shape, dtype=np.int64)},
160                 {'in': h_init_port, 'permutation': None}
161             )
162         else:
163             hidden_state_edge = graph.get_edge_data(rnn_cell.in_node(2).id, rnn_cell.id)
164             hidden_state_edge[0]['in'] = h_init_port
165
166         if rnn_cell.op == 'LSTM':
167             if 3 not in rnn_cell.in_nodes():
168                 c_shape = [num_directions, batch_size, rnn_cell.hidden_size]  # from ONNX spec
169                 c_init = np.full(c_shape, 0, dtype=np.float32)
170                 Op.create_and_connect_input_data_node(
171                     graph,
172                     rnn_cell,
173                     {'value': c_init, 'shape': np.array(c_init.shape, dtype=np.int64)},
174                     {'in': c_init_port, 'permutation': None}
175                 )
176             else:
177                 cell_state_edge = graph.get_edge_data(rnn_cell.in_node(3).id, rnn_cell.id)
178                 cell_state_edge[0]['in'] = c_init_port
179
180     @staticmethod
181     def add_output_reshape(graph: Graph, match: dict):
182         """
183         Since MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions] we need to add reshape
184         from above common format [batch_size, num_directions, seq_len, hidden_size] to MXNet format.
185         """
186         lstm = match['rnn_layer']
187         input = match['input']
188         if not lstm.has_num_directions:
189             return
190         old_data_node =lstm.out_node(0)
191         num_directions = 2 if lstm.direction in ['bidirectional'] else 1
192         mxnet_shape = lstm.out_node(0).shape.copy()
193
194         if lstm.batch_dim == 0:
195             mo_shape = np.array([input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim], lstm.hidden_size],
196                              dtype=np.int64)
197         else:
198             mo_shape = np.array([input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim], lstm.hidden_size],
199                                 dtype=np.int64)
200
201         if lstm.has_num_directions:
202             mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
203
204         new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
205         graph.remove_edge(lstm.id, old_data_node.id)
206         graph.add_edge(lstm.id, new_data.id, key=0, out=0)
207
208         # Add Permute
209         permute_order = np.array([0, 2, 1, 3], dtype=np.int64)
210         permute = Permute(graph, dict(order=permute_order))
211         permute_data = permute.create_node_with_data([new_data], dict(name=lstm.name + '/Permute_mxnet/'))
212
213         # Add Reshape
214         reshape = Reshape(graph, dict(dim=mxnet_shape))
215         reshape.create_node_with_data([permute_data], dict(name=lstm.name + '/Reshape_mxnet/'),
216                                       data_nodes=[old_data_node])
217
218     @staticmethod
219     def check_input_ports(graph: Graph, match: dict):
220         """
221         Check that all mandatory ports is present.
222         """
223         rnn_layer = match['rnn_layer']
224         mandatory_ports = [0, 1, 2, 3, 5]
225
226         if rnn_layer.op == 'LSTM':
227             mandatory_ports.append(6)
228
229         assert set(rnn_layer.in_nodes().keys()) >= set(mandatory_ports)