2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from mo.graph.graph import Graph
20 from mo.middle.replacement import MiddleReplacementPattern
21 from mo.ops.op import Op
22 from mo.ops.permute import Permute
23 from mo.ops.reshape import Reshape
class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
    """
    Convert blobs and shapes of an MXNet-like RNN cell to an IE compatible form.

    The target form of this operation is not normally covered by a dedicated
    layer in IE. It should be further transformed to some other layers
    that are supported by IE. This transformation pass involves weights and
    shapes processing only.

    Post-conditions:

    Inputs:
        0: X input data,    shape [batch_size, seq_len, input_size]
                            (or [seq_len, batch_size, input_size], depending on the layer's batch_dim)
        1: W weights blob,  shape [num_dir, n_cells, M, hidden_size, input_size]
        2: R weights blob,  shape [num_dir, n_cells, M, hidden_size, hidden_size]
        3: B biases blob,   shape [num_dir, n_cells, 2, M, hidden_size]
        4: (optional) sequence_length, shape [batch_size]
        5: initial hidden state, shape [num_dir, batch_size, hidden_size]
                            ([num_dir, n_cells, batch_size, hidden_size] if num_cells != 1)
        6: (only for LSTM) initial cell state, shape [num_dir, batch_size, hidden_size]
        7: (optional for LSTM) Peepholes weights, shape [num_dir, n_cells, (M - 1) * hidden_size]

    Outputs:
        0: Y output blob,  shape [batch_size, num_dir, seq_len, hidden_size]
        1: (optional) Y_h, shape [num_dir, batch_size, hidden_size]
        2: (optional for LSTM) Y_c, shape [num_dir, batch_size, hidden_size]

    Where:
        M -- number of gates in this cell (4 for LSTM, 3 for GRU, 1 for RNN).
        num_dir -- number of directions (2 for 'bidirectional', 1 for 'forward' or 'reverse').
        n_cells -- number of cells in layer (always 1 for ONNX; this pass always emits 1).
    """
    # NOTE(review): explicit opt-in flag; confirm it matches the registration convention of the base class.
    enabled = True

    def run_after(self):
        """Return the passes this one is ordered against.

        NOTE(review): assumed to be an "after" dependency, since the pattern below matches
        'RNNSequence' nodes that the split-layers pass is named for producing -- confirm.
        """
        # Imported lazily, inside the method, exactly as the original code does.
        from extensions.middle.MXNetSplitMultiLayers import MXNetSplitLayersToRNNSequence
        return [MXNetSplitLayersToRNNSequence]

    def pattern(self):
        """Describe the sub-graph to match: an MXNet-format RNNSequence op with its data and packed params inputs."""
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence', format='mxnet')),
                ('input', dict(kind='data')),
                ('params', dict(kind='data')),
            ],
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('params', 'rnn_layer', {'in': 1}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """Normalize one matched RNN layer in place and mark it with ``normalized = True``.

        :param graph: graph that owns the matched nodes
        :param match: mapping with the 'rnn_layer', 'input' and 'params' nodes from ``pattern()``
        """
        rnn_layer = match['rnn_layer']

        # Order matters: the initial states are moved to ports 5/6 first, which frees ports 2 and 3
        # for the R and B blobs that repack_weights connects afterwards.
        self.check_init_states(graph, match)
        self.repack_weights(graph, match)
        self.add_output_reshape(graph, match)
        self.check_input_ports(graph, match)
        rnn_layer['normalized'] = True

    @staticmethod
    def repack_weights(graph: Graph, match: dict):
        """Split the single packed params blob into W, R and B blobs and connect them to ports 1, 2 and 3.

        The packed vector is read as all weights followed by all biases. The edge from the original
        params node to the layer is removed and replaced with three new constant inputs.

        :param graph: graph that owns the matched nodes
        :param match: mapping with the 'rnn_layer', 'input' and 'params' nodes
        :raises AssertionError: if the split blobs do not have the expected trailing dimension
        """
        input_data = match['input']  # not named `input` to avoid shadowing the builtin
        rnn_layer = match['rnn_layer']
        params = match['params'].value.copy()

        graph.remove_edge(match['params'].id, rnn_layer.id)

        input_size = input_data.shape[2]
        hidden_size = rnn_layer.hidden_size
        multiplier = rnn_layer.multiplier
        direction = 2 if rnn_layer.has_num_directions else 1
        # Two bias components per gate, per direction, for a single cell.
        bsize = (2 * hidden_size * direction * 1) * multiplier

        # The tail of the vector holds the biases, everything before it the weights.
        W = np.array(params[0:len(params) - bsize])
        B = np.array(params[len(params) - bsize:])

        W = W.reshape((direction, -1))
        B = B.reshape((direction, -1))

        # Within each direction the input weights come first, then the recurrent weights.
        w_size = hidden_size * multiplier * input_size
        W, R = np.array(W[:, 0:w_size]), np.array(W[:, w_size:])

        W, R = [x.reshape([
                direction,  # 0: num of directions
                1,  # 1: num of cells
                multiplier,  # 2: output parts of the matrix for all gates
                hidden_size,  # 3: output size per direction and gate
                -1])  # 4: input size/hidden size in W/R correspondingly
                for x in (W, R)]

        assert W.shape[-1] == input_size, \
            'Unexpected W shape {} for input size {}'.format(W.shape, input_size)
        assert R.shape[-1] == hidden_size, \
            'Unexpected R shape {} for hidden size {}'.format(R.shape, hidden_size)

        B = B.reshape([
                 direction,  # 0: num of directions
                 1,  # 1: num of cells
                 2,  # 2: num of components of B
                 multiplier,  # 3: output parts of the matrix for all gates
                 hidden_size,  # 4: output size per direction and gate
        ])

        # Reorder gates along the gate axis using the order stored on the layer
        # (for LSTM the original comment describes this as ifco --> fico).
        gate_reorder = rnn_layer.gate_order
        W = np.take(W, gate_reorder, axis=2)
        R = np.take(R, gate_reorder, axis=2)
        B = np.take(B, gate_reorder, axis=3)

        for blob, port in [(W, 1), (R, 2), (B, 3)]:
            Op.create_and_connect_input_data_node(
                graph,
                rnn_layer,
                {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
                {'in': port, 'permutation': None}
            )

    @staticmethod
    def check_init_states(graph: Graph, match: dict):
        """
        Check if the cell has initial states and create zero states if not.
        Also renumber the ports of these states: hidden state 2 -> 5, cell state (LSTM only) 3 -> 6.

        :param graph: graph that owns the matched nodes
        :param match: mapping with the 'rnn_layer' node
        """
        rnn_cell = match['rnn_layer']
        num_directions = 2 if rnn_cell.direction == 'bidirectional' else 1
        batch_size = rnn_cell.in_node(0).shape[rnn_cell.batch_dim]

        # (port the state arrives on, port it must end up on)
        state_ports = [(2, 5)]  # hidden state
        if rnn_cell.op == 'LSTM':
            state_ports.append((3, 6))  # cell state

        for src_port, dst_port in state_ports:
            if src_port not in rnn_cell.in_nodes():
                state_shape = [num_directions, batch_size, rnn_cell.hidden_size]  # from ONNX spec
                state_init = np.full(state_shape, 0, dtype=np.float32)
                Op.create_and_connect_input_data_node(
                    graph,
                    rnn_cell,
                    {'value': state_init, 'shape': np.array(state_init.shape, dtype=np.int64)},
                    {'in': dst_port, 'permutation': None}
                )
            else:
                # The state is already provided: only move its edge to the target port.
                state_edge = graph.get_edge_data(rnn_cell.in_node(src_port).id, rnn_cell.id)
                state_edge[0]['in'] = dst_port

    @staticmethod
    def add_output_reshape(graph: Graph, match: dict):
        """
        Since MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions] we need to add reshape
        from above common format [batch_size, num_directions, seq_len, hidden_size] to MXNet format.

        Inserts ``layer -> new data -> Permute -> Reshape -> original output data`` so that consumers of
        the original output data node keep seeing the MXNet shape. Does nothing when the layer has no
        num_directions dimension.

        :param graph: graph that owns the matched nodes
        :param match: mapping with the 'rnn_layer' and 'input' nodes
        """
        lstm = match['rnn_layer']
        input_data = match['input']  # not named `input` to avoid shadowing the builtin
        if not lstm.has_num_directions:
            # NOTE(review): early exit assumed for layers without a num_directions dimension -- confirm.
            return

        old_data_node = lstm.out_node(0)
        num_directions = 2 if lstm.direction in ['bidirectional'] else 1
        mxnet_shape = old_data_node.shape.copy()

        if lstm.batch_dim == 0:
            mo_shape = np.array([input_data.shape[lstm.batch_dim], input_data.shape[lstm.sequence_dim],
                                 lstm.hidden_size], dtype=np.int64)
        else:
            mo_shape = np.array([input_data.shape[lstm.sequence_dim], input_data.shape[lstm.batch_dim],
                                 lstm.hidden_size], dtype=np.int64)

        # has_num_directions is known to be set here (see the guard above), so the
        # directions axis is always inserted at position 1.
        mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

        new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
        graph.remove_edge(lstm.id, old_data_node.id)
        graph.add_edge(lstm.id, new_data.id, key=0, out=0)

        # Add Permute: swap axes 1 and 2 of the 4D output
        permute_order = np.array([0, 2, 1, 3], dtype=np.int64)
        permute = Permute(graph, dict(order=permute_order))
        permute_data = permute.create_node_with_data([new_data], dict(name=lstm.name + '/Permute_mxnet/'))

        # Add Reshape: back to the original MXNet output shape, writing into the original data node
        reshape = Reshape(graph, dict(dim=mxnet_shape))
        reshape.create_node_with_data([permute_data], dict(name=lstm.name + '/Reshape_mxnet/'),
                                      data_nodes=[old_data_node])

    @staticmethod
    def check_input_ports(graph: Graph, match: dict):
        """
        Check that all mandatory ports are present.

        :param graph: graph that owns the matched nodes (unused, kept for a uniform signature)
        :param match: mapping with the 'rnn_layer' node
        :raises AssertionError: if any mandatory input port of the layer is not connected
        """
        rnn_layer = match['rnn_layer']
        mandatory_ports = [0, 1, 2, 3, 5]

        if rnn_layer.op == 'LSTM':
            mandatory_ports.append(6)

        present_ports = set(rnn_layer.in_nodes().keys())
        assert present_ports >= set(mandatory_ports), \
            'Mandatory input ports {} are missing'.format(sorted(set(mandatory_ports) - present_ports))