Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / RNNSequenceNormalizeToIE.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 from copy import deepcopy
17
18 import numpy as np
19
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
22 from mo.ops.op import Op
23 from mo.ops.reshape import Reshape
24
25
26 class RNNSequenceNormalize(MiddleReplacementPattern):
27     """
28     This class normalize RNNSequence layers to IE-compatible from of weights, inputs and outputs.
29
30     In this pass next will be done:
31         1. Weights repack (squeeze all useless shapes in all blobls and concatenate W and R together, also add
32                             bin param and all similar staff )
33         1. UNSqueeze num directions (in states and )
34         2. Initial states squeeze
35         4. Renumbering inputs
36         5. Ports checks
37
38     After this normalization this layer will have next format of inputs:
39             0: X input data,    shape [batch_size, seq_len, input_size]
40             1: WR weights blob,  shape [M * hidden_size, hidden_size + input_size]
41             2: B biases blob,   shape [M * hidden_size]
42             3: (optional) sequence_length, shape [batch_size]
43             4: initial hidden state, shape  [batch_size, hidden_size]
44             5: (only for LSTM) initial cell state, shape [batch_size, hidden_size]
45             6: (optional for LSTM) Peepholes weights, shape  [(M - 1) * hidden_size]
46
47     """
48     def run_after(self):
49         from extensions.middle.DecomposeBidirectionalRNNSequence import DecomposeBidirectionalRNNSequence
50         return [DecomposeBidirectionalRNNSequence]
51
52     def pattern(self):
53         return dict(
54             nodes=[
55                 ('rnn_layer', dict(kind='op', type='RNNSequence')),
56                 ('input', dict(kind='data')),
57                 ('W', dict(kind='data')),
58                 ('R', dict(kind='data')),
59                 ('B', dict(kind='data')),
60             ],
61             edges=[
62                 ('input', 'rnn_layer', {'in': 0}),
63                 ('W', 'rnn_layer', {'in': 1}),
64                 ('R', 'rnn_layer', {'in': 2}),
65                 ('B', 'rnn_layer', {'in': 3}),
66             ],
67         )
68
69     def replace_pattern(self, graph: Graph, match: dict):
70         self.repack_weights(graph, match)
71         if match['rnn_layer'].has_num_directions:
72             self.unsqueeze_num_directions(graph, match)
73         self.squeeze_initial_states(graph, match)
74         self.reordering_inputs(graph, match)
75         # some additional checks for ports number and similar stuff
76
77     def repack_weights(self, graph: Graph, match: dict):
78         # Concat W, R in IE- format
79         # Delete useless num_dir dimensions and n_cells dimensions in W, R, B (peepholes?)
80         lstm = match['rnn_layer']
81         W, R, B = match['W'].value.copy(), match['R'].value.copy(), match['B'].value.copy()
82
83         graph.remove_edge(match['W'].id, lstm.id)
84         graph.remove_edge(match['R'].id, lstm.id)
85         graph.remove_edge(match['B'].id, lstm.id)
86
87         # Sum component of B that correspond to W and R
88         if lstm.op == 'GRU' and lstm.linear_before_reset:
89             B_shape = np.array(B.shape)
90             B_shape[3] = 4
91             B_shape[2] = 1
92             B_tmp = np.zeros(shape=B_shape)
93             B_tmp[:, :, :, 0, :] = B[:, :, 0, 0, :] + B[:, :, 1, 0, :]
94             B_tmp[:, :, :, 1, :] = B[:, :, 0, 1, :] + B[:, :, 1, 1, :]
95             B_tmp[:, :, :, 2, :] = B[:, :, 0, 2, :][:, :, np.newaxis, :]
96             B_tmp[:, :, :, 3, :] = B[:, :, 1, 2, :][:, :, np.newaxis, :]
97             B = B_tmp
98         else:
99             B = np.add.reduce(B, axis=2, keepdims=True)
100
101         # Concatenate W, R to IE-compatible format
102         assert len(W.shape) == 5
103         assert len(R.shape) == 5
104         WR = np.concatenate([W, R], axis=4)
105
106         # Squeeze useless dimensions
107         assert WR.shape[0] == 1  # num_dir == 1
108         assert WR.shape[1] == 1  # num_cells == 1
109         assert B.shape[0] == 1
110         assert B.shape[1] == 1
111         WR = WR.squeeze(axis=(0, 1))
112         B = B.squeeze(axis=(0, 1))
113
114         # Flatten all output (0, 1) and input dimensions (2, 3)
115         final_shape_WR = [WR.shape[0] * WR.shape[1], -1]
116         assert final_shape_WR[0] == lstm.hidden_size * lstm.multiplier
117         WR = WR.reshape(final_shape_WR)
118
119         final_shape_B = final_shape_WR
120         if lstm.op == 'GRU' and lstm.linear_before_reset:
121             final_shape_B[0] = lstm.hidden_size * 4
122         B = B.reshape(final_shape_B)
123
124         # Squeeze fake dimension in B
125         B = B.squeeze(axis=-1)
126
127         for blob, port, name in [(WR, 1, 'weights'), (B, 2, 'biases')]:
128             Op.create_and_connect_input_data_node(
129                 graph,
130                 lstm,
131                 {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
132                 {'in': port, 'bin': name, 'permutation': None}
133             )
134
135     @staticmethod
136     def unsqueeze_num_directions(graph: Graph, match: dict):
137         """ Assuming considered LSTM/GRU/RNN node should has num_directions in output shape and add Reshape
138             to match it.
139         """
140
141         rnn_layer = match['rnn_layer']
142         # num_directions is at 1st position in output shape, and in 0st position in hidden and cell states
143         # please refer to docs in this transform
144
145         direction_dim = [1, 0, 0]  # index of dimension with direction index
146         for i in rnn_layer.out_nodes():
147             old_data_node = rnn_layer.out_node(i)
148             old_shape = old_data_node.shape.copy()
149             new_shape = np.delete(old_shape, direction_dim[i])
150
151             data = Op._create_data_node(graph, name=rnn_layer.name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
152             graph.remove_edge(rnn_layer.id, old_data_node.id)
153             graph.add_edge(rnn_layer.id, data.id, key=0, out=i)
154
155             reshape = Reshape(graph, dict(dim=old_shape))
156             reshape.create_node_with_data([data], dict(name=rnn_layer.name + '/SqueezeNumDirections/{}'.format(i)),
157                                           data_nodes=[old_data_node])
158
159     @staticmethod
160     def squeeze_initial_states(graph: Graph, match: dict):
161         """
162         Squeeze input initial states of recurrent node to 2-D shape.
163         """
164         hidden_init_port = 5
165         cell_init_port = 6
166
167         rnn_layer = match['rnn_layer']
168
169         reshape = Reshape(graph, dict(dim=[rnn_layer.in_node(0).shape[rnn_layer.batch_dim], rnn_layer.hidden_size]))
170
171         assert hidden_init_port in rnn_layer.in_nodes()
172         init_h = rnn_layer.in_node(hidden_init_port)
173         edge_attrs = deepcopy(graph.get_edge_data(init_h.id, rnn_layer.id)[0])
174         edge_attrs['in'] = hidden_init_port
175         graph.remove_edge(init_h.id, rnn_layer.id)
176         new_init_h = reshape.create_node_with_data([init_h], dict(name=rnn_layer.name + '/HiddenStateResize'))
177         graph.add_edge(new_init_h.id, rnn_layer.id, **edge_attrs)
178
179         if rnn_layer.op == 'LSTM':
180             assert cell_init_port in rnn_layer.in_nodes()
181
182             init_c = rnn_layer.in_node(cell_init_port)
183             edge_attrs = deepcopy(graph.get_edge_data(init_c.id, rnn_layer.id)[0])
184             edge_attrs['in'] = cell_init_port
185             graph.remove_edge(init_c.id, rnn_layer.id)
186             new_init_c = reshape.create_node_with_data([init_c], dict(name=rnn_layer.name + '/CellStateResize'))
187             graph.add_edge(new_init_c.id, rnn_layer.id, **edge_attrs)
188
189     @staticmethod
190     def reordering_inputs(graph: Graph, match: dict):
191         """
192         Reorder (renumbering) inputs to described format. We need to renumber initial states ports.
193         """
194         rnn_layer = match['rnn_layer']
195         assert 5 in rnn_layer.in_nodes()
196         hidden_state_edge = graph.get_edge_data(rnn_layer.in_node(5).id, rnn_layer.id)
197         hidden_state_edge[0]['in'] = 4
198
199         if rnn_layer.op == 'LSTM':
200             assert 6 in rnn_layer.in_nodes()
201             cell_state_edge = graph.get_edge_data(rnn_layer.in_node(6).id, rnn_layer.id)
202             cell_state_edge[0]['in'] = 5
203
204     @staticmethod
205     def ports_checks(graph: Graph, match: dict):
206         """
207             Check that all mandatory ports is present.
208         """
209         rnn_layer = match['rnn_layer']
210         mandatory_ports = [0, 1, 2, 4]
211
212         if rnn_layer.op == 'LSTM':
213             mandatory_ports.append(5)
214
215         assert set(rnn_layer.in_nodes().keys()) >= set(mandatory_ports)