"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from copy import deepcopy

import numpy as np

from mo.graph.graph import Node, Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.permute import Permute


def permute_before_and_after(inp: Node, middle: Node, out: Node, input_order, output_order):
    """
    Insert two Permute ops: one before the middle node and one after it.

    Each Permute uses the corresponding given order (input_order/output_order).
    """
    # Permute before input
    permute = Permute(middle.graph, dict(order=np.array(input_order)))

    edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
    middle.graph.remove_edge(inp.id, middle.id)
    new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
    middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)

    # Permute after output
    permute = Permute(middle.graph, dict(order=np.array(output_order)))

    middle.graph.remove_edge(middle.id, out.id)
    new_out = Op._create_data_node(middle.graph, name=middle.name + '/WithoutPermute',
                                   attrs={'shape': out.shape[output_order]})
    middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
    permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
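

# A minimal sketch of what the input Permute computes, using plain numpy and
# hypothetical shapes (illustration only):
#
#     >>> x = np.zeros([10, 4, 16])        # [seq_len, batch, input_size]
#     >>> x.transpose([1, 0, 2]).shape     # input_order == [1, 0, 2]
#     (4, 10, 16)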


class ONNXRNNSequenceNormalize(MiddleReplacementPattern):
    """
    Convert blobs and shapes of ONNX-like LSTM, GRU and RNN cells to the common form
    (internal for MO). After this normalization, the passes that split bidirectional
    calls and multilayer cells are applied.

    This transformation pass involves weights and shapes processing only:
        1. Weights reshaping and reordering

    Inputs will have the following order after normalization:
        0: X input data,    shape [batch_size, seq_len, input_size]
        1: W weights blob,  shape [num_dir, n_cells, M, hidden_size, input_size]
        2: R weights blob,  shape [num_dir, n_cells, M, hidden_size, hidden_size]
        3: B biases blob,   shape [num_dir, n_cells, 2, M, hidden_size]
        4: (optional) sequence_length, shape [batch_size]
        5: initial hidden state, shape [num_dir, batch_size, hidden_size]
                                       ([num_dir, n_cells, batch_size, hidden_size] if n_cells != 1)
        6: (only for LSTM) initial cell state, shape [num_dir, batch_size, hidden_size]
        7: (optional for LSTM) peephole weights, shape [num_dir, n_cells, (M - 1) * hidden_size]

    Outputs:
        0: Y output blob,  shape [batch_size, num_dir, seq_len, hidden_size]
        1: (optional) Y_h, shape [num_dir, batch_size, hidden_size]
        2: (optional for LSTM) Y_c, shape [num_dir, batch_size, hidden_size]

    Where:
        M -- number of gates in this cell (4 for LSTM, 3 for GRU, 1 for RNN)
        num_dir -- number of directions ('forward', 'bidirectional', 'reverse')
        n_cells -- number of cells in the layer (always 1 for ONNX)
    """
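
    # Worked example (hypothetical sizes): a bidirectional ONNX LSTM with
    # input_size=64 and hidden_size=128 is normalized to
    #     W: [2, 1, 4, 128, 64], R: [2, 1, 4, 128, 128], B: [2, 1, 2, 4, 128],
    # i.e. num_dir=2, n_cells=1, M=4.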

    def pattern(self):
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence', format='onnx')),
                ('input', dict(kind='data')),
                ('W', dict(kind='data')),
                ('R', dict(kind='data')),
            ],
            # We are not handling optional inputs
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('W', 'rnn_layer', {'bin': 'W'}),
                ('R', 'rnn_layer', {'bin': 'R'}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        self.repack_weights(graph, match)
        self.check_init_states(graph, match)
        self.check_input_ports(graph, match)
        match['rnn_layer']['normalized'] = True

    @staticmethod
    def repack_weights(graph: Graph, match: dict):
        """
        Repack weights into the general format (described above) and reorder gates.
        """
        rnn_layer = match['rnn_layer']
        W = match['W'].value.copy()
        R = match['R'].value.copy()
        num_directions = 2 if rnn_layer.direction == 'bidirectional' else 1

        graph.remove_edge(match['W'].id, rnn_layer.id)
        graph.remove_edge(match['R'].id, rnn_layer.id)

        # find optional 'B' biases blob
        if 3 in rnn_layer.in_nodes():
            # TODO: check if 'bin': 'B' attribute is assigned to this edge
            B = rnn_layer.in_node(3).value.copy()
            graph.remove_edge(rnn_layer.in_node(3).id, rnn_layer.id)
        else:
            B_shape = [num_directions, 2 * rnn_layer.multiplier * rnn_layer.hidden_size]  # from ONNX spec
            B = np.full(B_shape, 0, dtype=np.float32)
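            # (e.g. a hypothetical bidirectional LSTM with hidden_size=128 gives
            # B_shape == [2, 2 * 4 * 128] == [2, 1024]: one zero bias vector per
            # direction covering both the W and R bias halves)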

        # Add extra dimensions for W, R and B for easier repacking and reordering
        B = B.reshape([
            num_directions,  # 0: num of directions
            rnn_layer.num_layers,  # 1: num_layers
            2,  # 2: two input parts of the matrix: W, R
            rnn_layer.multiplier,  # 3: M output parts of the matrix, one per gate, in order: i, o, f, c
            rnn_layer.hidden_size,  # 4: output size per direction and gate
        ])

        W, R = [x.reshape([
            num_directions,  # 0: num of directions
            rnn_layer.num_layers,  # 1: num_layers
            rnn_layer.multiplier,  # 2: M output parts of the matrix, one per gate, in order: i, o, f, c
            rnn_layer.hidden_size,  # 3: output size per direction and gate
            -1])  # 4: input size / hidden size in W / R correspondingly
            for x in (W, R)]

        input_size = match['input'].shape[2]
        assert input_size == W.shape[-1]

        # Reorder gates: iofc --> fico
        gate_reorder = rnn_layer.gate_order
        W, R = (np.take(x, gate_reorder, axis=2) for x in (W, R))
        B = np.take(B, gate_reorder, axis=3)

        for blob, port in [(W, 1), (R, 2), (B, 3)]:
            Op.create_and_connect_input_data_node(
                graph,
                rnn_layer,
                {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
                {'in': port, 'permutation': None}
            )
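
    # A minimal sketch of the gate reordering above (illustration only; the
    # gate_order value is hypothetical):
    #
    #     >>> w = np.arange(4).reshape([1, 1, 4, 1, 1])   # gates on axis 2: i, o, f, c
    #     >>> np.take(w, [2, 0, 3, 1], axis=2).ravel()    # iofc --> fico
    #     array([2, 0, 3, 1])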

    @staticmethod
    def batch_sequence_transpose(graph: Graph, match: dict):
        """
        Bring the layer input to the batch-major form [batch_size, seq_len, input_size],
        inserting Permute ops around the layer if it is sequence-major.
        """
        rnn_layer = match['rnn_layer']
        inp = match['input']
        out = rnn_layer.out_node(0)

        if rnn_layer.batch_dim == 0:
            assert rnn_layer.sequence_dim == 1
            # nothing to do -- it's already in normal form
            return
        else:
            assert rnn_layer.sequence_dim == 0
            assert rnn_layer.batch_dim == 1
            assert len(inp.shape) == 3

            # Reorder the first two dimensions on both ends: input and output.
            # Two Permute ops are inserted before and after the LSTM node.
            # This transformation doesn't analyze the rest of the model around the
            # LSTM cell, so the Permute ops are not fused into other layers here,
            # but later transformations in the pipeline may optimize them out.

            rnn_layer.batch_dim, rnn_layer.sequence_dim = rnn_layer.sequence_dim, rnn_layer.batch_dim
            permute_before_and_after(inp, rnn_layer, out, [1, 0, 2], [2, 1, 0, 3])
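
    # For illustration (hypothetical shapes): order [1, 0, 2] turns a
    # sequence-major input [seq_len=10, batch=4, input_size=16] into the
    # batch-major form [4, 10, 16]; order [2, 1, 0, 3] likewise swaps the first
    # and third dimensions of the 4D output blob back to the layout the rest of
    # the graph expects.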

    @staticmethod
    def check_init_states(graph: Graph, match: dict):
        """
        Check if the cell has initial states; create zero-filled states if not.
        """
        rnn_layer = match['rnn_layer']
        num_directions = 2 if rnn_layer.direction == 'bidirectional' else 1
        batch_size = rnn_layer.in_node(0).shape[rnn_layer.batch_dim]

        h_init_port = 5
        c_init_port = 6

        if h_init_port not in rnn_layer.in_nodes():
            h_shape = [num_directions, batch_size, rnn_layer.hidden_size]  # from ONNX spec
            h_init = np.full(h_shape, 0, dtype=np.float32)
            Op.create_and_connect_input_data_node(
                graph,
                rnn_layer,
                {'value': h_init, 'shape': np.array(h_init.shape, dtype=np.int64)},
                {'in': h_init_port, 'permutation': None}
            )

        if rnn_layer.op == 'LSTM':
            if c_init_port not in rnn_layer.in_nodes():
                c_shape = [num_directions, batch_size, rnn_layer.hidden_size]  # from ONNX spec
                c_init = np.full(c_shape, 0, dtype=np.float32)
                Op.create_and_connect_input_data_node(
                    graph,
                    rnn_layer,
                    {'value': c_init, 'shape': np.array(c_init.shape, dtype=np.int64)},
                    {'in': c_init_port, 'permutation': None}
                )

    @staticmethod
    def check_input_ports(graph: Graph, match: dict):
        """
        Check that all mandatory input ports are present.
        """
        rnn_layer = match['rnn_layer']
        mandatory_ports = [0, 1, 2, 3, 5]

        if rnn_layer.op == 'LSTM':
            mandatory_ports.append(6)

        assert set(rnn_layer.in_nodes().keys()) >= set(mandatory_ports)
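

if __name__ == '__main__':
    # Standalone sketch of the weight repacking done in repack_weights, using
    # plain numpy and hypothetical sizes (num_dir=2, n_cells=1, M=4 gates,
    # hidden_size=3, input_size=5). Illustration only; not used by MO itself.
    num_dir, n_cells, M, hidden_size, input_size = 2, 1, 4, 3, 5
    w = np.random.rand(num_dir, M * hidden_size, input_size).astype(np.float32)  # ONNX W layout
    w = w.reshape([num_dir, n_cells, M, hidden_size, -1])  # add cell/gate axes
    w = np.take(w, [2, 0, 3, 1], axis=2)  # reorder gates iofc --> fico (hypothetical gate_order)
    assert w.shape == (num_dir, n_cells, M, hidden_size, input_size)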