Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ONNXRNNSequenceNormalize.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from copy import deepcopy
18
19 import numpy as np
20
21 from mo.graph.graph import Node, Graph
22 from mo.middle.replacement import MiddleReplacementPattern
23 from mo.ops.op import Op
24 from mo.ops.permute import Permute
25
26
27 def permute_before_and_after(inp: Node, middle: Node, out: Node, input_order, output_order):
28     """
29         Insert two permutes: before middle node and after middle node.
30
31         Both permutes has a given order (input/output).
32     """
33     # Permute before input
34     permute = Permute(middle.graph, dict(order=np.array(input_order)))
35
36     edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
37     middle.graph.remove_edge(inp.id, middle.id)
38     new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
39     middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)
40
41     # Permute after output
42     permute = Permute(middle.graph, dict(order=output_order))
43
44     middle.graph.remove_edge(middle.id, out.id)
45     new_out = Op._create_data_node(middle.graph, name=middle.name + '/WithoutPermute',
46                                    attrs={'shape': out.shape[output_order]})
47     middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
48     permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
49
50
51 class ONNXRNNSequenceNormalize(MiddleReplacementPattern):
52     """
53         Convert blobs and shapes of ONNX-like LSTM, GRU, RNN cells to common form (internal for MO).
54         After this normalization pass passes for spliting bidirectional calls and
55         multilayer cells will be applied.
56
57         This transformation pass involves weights and shapes processing only:
58             1. Weights reshaping and reordering
59             2. Gates reordering
60
61
62         Inputs will have the following order after normalising:
63             0: X input data,    shape [batch_size, seq_len, input_size]
64             1: W weights blob,  shape [num_dir, n_cells, M, hidden_size, input_size]
65             2: R weights blob,  shape [num_dir, n_cells, M, hidden_size, hidden_size]
66             3: B biases blob,   shape [num_dir, n_cells, 2, M, hidden_size]
67             4: (optional) sequence_length, shape [batch_size]
68             5: initial hidden state, shape  [num_dir, batch_size, hidden_size]
69                                                       ([num_dir, n_cells, batch_size, hidden_size] if num_cells != 1)
70             6: (only for LSTM) initial cell state, shape [num_dir, batch_size, hidden_size]
71             7: (optional for LSTM) Peepholes weights, shape  [num_dir, n_cells, (M - 1) * hidden_size]
72
73         Outputs:
74             0: Y output blob,  shape [batch_size, num_dir, seq_len, hidden_size]
75             1: (optional) Y_h, shape [num_dir, batch_size, hidden_size]
76             2: (optional for LSTM) Y_c, shape [num_dir, batch_size, hidden_size]
77
78         Where:
79             M -- number of gates in this cell (4 for LSTM, 3 for GRU, 1 for RNN).
80             num_dir -- number of directions ('forvard', 'bidirectional', 'reverse')
81             n_cells -- number of cells in layer (always 1 for ONNX).
82     """
83
84     enabled = True
85
86     def pattern(self):
87         return dict(
88             nodes=[
89                 ('rnn_layer', dict(kind='op', type='RNNSequence', format='onnx')),
90                 ('input', dict(kind='data')),
91                 ('W', dict(kind='data')),
92                 ('R', dict(kind='data')),
93             ],
94             # We are not handling optional inputs
95             edges=[
96                 ('input', 'rnn_layer', {'in': 0}),
97                 ('W', 'rnn_layer', {'bin': 'W'}),
98                 ('R', 'rnn_layer', {'bin': 'R'}),
99             ]
100         )
101
102     def replace_pattern(self, graph: Graph, match: dict):
103         self.repack_weights(graph, match)
104         self.check_init_states(graph, match)
105         self.check_input_ports(graph, match)
106         match['rnn_layer']['normalized'] = True
107
108     @staticmethod
109     def repack_weights(graph: Graph, match: dict):
110         """
111         Repack weights into general format (described above) and reorder gates.
112         """
113         rnn_layer = match['rnn_layer']
114         W = match['W'].value.copy()
115         R = match['R'].value.copy()
116         num_directions = 2 if rnn_layer.direction == 'bidirectional' else 1
117
118         graph.remove_edge(match['W'].id, rnn_layer.id)
119         graph.remove_edge(match['R'].id, rnn_layer.id)
120
121         # find optional 'B' biases blob
122         if 3 in rnn_layer.in_nodes():
123             # TODO: check if 'bin': 'B' attribute is assigned to this edge
124             B = rnn_layer.in_node(3).value.copy()
125             graph.remove_edge(rnn_layer.in_node(3).id, rnn_layer.id)
126         else:
127             B_shape = [num_directions, 2 * rnn_layer.multiplier * rnn_layer.hidden_size]  # from ONNX spec
128             B = np.full(B_shape, 0, dtype=np.float32)
129
130         # Add extra dimensions for W, R and B for easier repacking and reordering
131         B = B.reshape([
132             num_directions,  # 0: num of directions
133             rnn_layer.num_layers,  # 1: num_layers
134             2,  # 2: two input parts of the matrix: W, R
135             rnn_layer.multiplier,  # 3: four output parts of the matrix for all gates in order: i, o, f, c
136             rnn_layer.hidden_size,  # 4: output size per direction and gate
137         ])
138
139         W, R = [x.reshape([
140                 num_directions,  # 0: num of directions
141                 rnn_layer.num_layers,  # 1: num_layers
142                 rnn_layer.multiplier,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
143                 rnn_layer.hidden_size,  # 3: output size per direction and gate
144                 -1])  # 4: input size/hidden size in W/R correspondingly
145                 for x in (W, R)]
146
147         input_size = match['input'].shape[2]
148         assert input_size == W.shape[-1]
149
150         # Reorder gates: iofc --> fico
151         gate_reorder = rnn_layer.gate_order
152         W, R = (np.take(x, gate_reorder, axis=2) for x in (W, R))
153         B = np.take(B, gate_reorder, axis=3)
154
155         for blob, port in [(W, 1), (R, 2), (B, 3)]:
156             Op.create_and_connect_input_data_node(
157                 graph,
158                 rnn_layer,
159                 {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
160                 {'in': port, 'permutation': None}
161             )
162
163     @staticmethod
164     def batch_sequence_transpose(graph: Graph, match: dict):
165         """
166
167         """
168         rnn_layer = match['rnn_layer']
169         inp = match['input']
170         out = rnn_layer.out_node(0)
171
172         if rnn_layer.batch_dim == 0:
173             assert rnn_layer.sequence_dim == 1
174             # nothing to do -- it's already in normal form
175             return
176
177         assert rnn_layer.sequence_dim == 0
178         assert rnn_layer.batch_dim == 1
179         assert len(inp.shape) == 3
180
181         # Reorder the first two dimensions on both ends: input and output.
182         # Two Permute ops are inserted before and after the LSTM node.
183         # In this transformation we don't analyze the rest of the model around
184         # LSTM cell, so these Permute ops are not fused to some other layers here.
185         # But other transformations in the pipeline may optimize the Permute ops out.
186
187         rnn_layer.batch_dim, rnn_layer.sequence_dim = rnn_layer.sequence_dim, rnn_layer.batch_dim
188         permute_before_and_after(inp, rnn_layer, out, [1, 0, 2], [2, 1, 0, 3])
189
190     @staticmethod
191     def check_init_states(graph: Graph, match: dict):
192         """
193         Check if cell have initial states and create zeros states if not.
194         """
195         rnn_layer = match['rnn_layer']
196         num_directions = 2 if rnn_layer.direction == 'bidirectional' else 1
197         batch_size = rnn_layer.in_node(0).shape[rnn_layer.batch_dim]
198
199         h_init_port = 5
200         c_init_port = 6
201
202         if h_init_port not in rnn_layer.in_nodes():
203             h_shape = [num_directions, batch_size, rnn_layer.hidden_size]  # from ONNX spec
204             h_init = np.full(h_shape, 0, dtype=np.float32)
205             Op.create_and_connect_input_data_node(
206                 graph,
207                 rnn_layer,
208                 {'value': h_init, 'shape': np.array(h_init.shape, dtype=np.int64)},
209                 {'in': h_init_port, 'permutation': None}
210             )
211
212         if rnn_layer.op == 'LSTM':
213             if c_init_port not in rnn_layer.in_nodes():
214                 c_shape = [num_directions, batch_size, rnn_layer.hidden_size]  # from ONNX spec
215                 c_init = np.full(c_shape, 0, dtype=np.float32)
216                 Op.create_and_connect_input_data_node(
217                     graph,
218                     rnn_layer,
219                     {'value': c_init, 'shape': np.array(c_init.shape, dtype=np.int64)},
220                     {'in': c_init_port, 'permutation': None}
221                 )
222
223     @staticmethod
224     def check_input_ports(graph: Graph, match: dict):
225         """
226         Check that all mandatory ports is present.
227         """
228         rnn_layer = match['rnn_layer']
229         mandatory_ports = [0, 1, 2, 3, 5]
230
231         if rnn_layer.op == 'LSTM':
232             mandatory_ports.extend([6])
233
234         assert set(rnn_layer.in_nodes().keys()) >= set(mandatory_ports)