"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from copy import deepcopy

import numpy as np

from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
class RNNSequenceNormalize(MiddleReplacementPattern):
    """
    This class normalizes RNNSequence layers to the IE-compatible format of weights, inputs and outputs.

    This pass performs the following steps:
        1. Repack the weights: squeeze all redundant dimensions in all blobs, concatenate W and R
           together, add the 'bin' attribute and similar bookkeeping.
        2. Unsqueeze num_directions (in the states and outputs) via Reshape nodes.
        3. Squeeze the initial states to 2-D.
        4. Renumber the input ports.

    After this normalization the layer has the following input format:
        0: X input data,    shape [batch_size, seq_len, input_size]
        1: WR weights blob, shape [M * hidden_size, hidden_size + input_size]
        2: B biases blob,   shape [M * hidden_size]
        3: (optional) sequence_length, shape [batch_size]
        4: initial hidden state, shape [batch_size, hidden_size]
        5: (only for LSTM) initial cell state, shape [batch_size, hidden_size]
        6: (optional for LSTM) peephole weights, shape [(M - 1) * hidden_size]
    """
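
    # Illustrative shape walkthrough of repack_weights (comment only; the concrete sizes
    # below are assumed for the example, not taken from a real model).
    # For an LSTM (M = 4 gates) with input_size = 16, hidden_size = 32, one direction, one cell:
    #   W: [1, 1, 4, 32, 16]    R: [1, 1, 4, 32, 32]    B: [1, 1, 2, 4, 32]
    #   WR = concat(W, R, axis=4)   -> [1, 1, 4, 32, 48]
    #   squeeze axes (0, 1)         -> [4, 32, 48]
    #   reshape to [4 * 32, -1]     -> [128, 48] == [M * hidden_size, hidden_size + input_size]
    #   B: sum W/R parts, squeeze,
    #      reshape, squeeze last    -> [128]     == [M * hidden_size]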
    def run_after(self):
        from extensions.middle.DecomposeBidirectionalRNNSequence import DecomposeBidirectionalRNNSequence
        return [DecomposeBidirectionalRNNSequence]
    def pattern(self):
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence')),
                ('input', dict(kind='data')),
                ('W', dict(kind='data')),
                ('R', dict(kind='data')),
                ('B', dict(kind='data')),
            ],
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('W', 'rnn_layer', {'in': 1}),
                ('R', 'rnn_layer', {'in': 2}),
                ('B', 'rnn_layer', {'in': 3}),
            ]
        )
    def replace_pattern(self, graph: Graph, match: dict):
        self.repack_weights(graph, match)
        if match['rnn_layer'].has_num_directions:
            self.unsqueeze_num_directions(graph, match)
        self.squeeze_initial_states(graph, match)
        self.reordering_inputs(graph, match)
        # some additional checks for port numbers and similar stuff
    def repack_weights(self, graph: Graph, match: dict):
        # Concatenate W and R in the IE format
        # and delete the redundant num_directions and num_cells dimensions from W, R and B (peepholes?)
        lstm = match['rnn_layer']
        W, R, B = match['W'].value.copy(), match['R'].value.copy(), match['B'].value.copy()

        graph.remove_edge(match['W'].id, lstm.id)
        graph.remove_edge(match['R'].id, lstm.id)
        graph.remove_edge(match['B'].id, lstm.id)
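
        # Expected incoming 5-D blob layouts (enforced by the asserts below), assuming the usual
        # Model Optimizer packing where M is the number of gates (4 for LSTM, 3 for GRU, 1 for RNN):
        #   W: [num_directions, num_cells, M, hidden_size, input_size]
        #   R: [num_directions, num_cells, M, hidden_size, hidden_size]
        #   B: [num_directions, num_cells, 2, M, hidden_size]  (axis 2 holds the separate W and R bias parts)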
        # Sum the components of B that correspond to W and R
        if lstm.op == 'GRU' and lstm.linear_before_reset:
            # With linear_before_reset the W and R parts of the hidden-gate bias must stay
            # separate, so B is repacked from [.., .., 2, 3, ..] (W/R parts x 3 gates) into
            # [.., .., 1, 4, ..]: summed update and reset biases plus separate Wb_h and Rb_h
            B_shape = np.array(B.shape)
            B_shape[3] = 4
            B_shape[2] = 1
            B_tmp = np.zeros(shape=B_shape)
            B_tmp[:, :, :, 0, :] = B[:, :, 0, 0, :] + B[:, :, 1, 0, :]
            B_tmp[:, :, :, 1, :] = B[:, :, 0, 1, :] + B[:, :, 1, 1, :]
            B_tmp[:, :, :, 2, :] = B[:, :, 0, 2, :][:, :, np.newaxis, :]
            B_tmp[:, :, :, 3, :] = B[:, :, 1, 2, :][:, :, np.newaxis, :]
            B = B_tmp
        else:
            B = np.add.reduce(B, axis=2, keepdims=True)
        # Concatenate W, R into the IE-compatible format
        assert len(W.shape) == 5
        assert len(R.shape) == 5
        WR = np.concatenate([W, R], axis=4)

        # Squeeze the redundant dimensions
        assert WR.shape[0] == 1  # num_dir == 1
        assert WR.shape[1] == 1  # num_cells == 1
        assert B.shape[0] == 1
        assert B.shape[1] == 1
        WR = WR.squeeze(axis=(0, 1))
        B = B.squeeze(axis=(0, 1))

        # Flatten the output dimensions (0, 1) and the input dimensions (2, 3)
        final_shape_WR = [WR.shape[0] * WR.shape[1], -1]
        assert final_shape_WR[0] == lstm.hidden_size * lstm.multiplier
        WR = WR.reshape(final_shape_WR)
        final_shape_B = final_shape_WR.copy()  # copy so the in-place edit below cannot alias final_shape_WR
        if lstm.op == 'GRU' and lstm.linear_before_reset:
            final_shape_B[0] = lstm.hidden_size * 4
        B = B.reshape(final_shape_B)

        # Squeeze the fake last dimension in B
        B = B.squeeze(axis=-1)
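
        # Note: the 'bin' edge attribute set in the loop below is what makes the IR serializer
        # dump these blobs into the .bin file as the layer's 'weights'/'biases'; 'permutation'
        # is reset so the blobs stay out of the layout-permutation machinery (assumed MO behavior).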
        for blob, port, name in [(WR, 1, 'weights'), (B, 2, 'biases')]:
            Op.create_and_connect_input_data_node(
                graph,
                lstm,
                {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
                {'in': port, 'bin': name, 'permutation': None}
            )
    @staticmethod
    def unsqueeze_num_directions(graph: Graph, match: dict):
        """ Assuming the considered LSTM/GRU/RNN node should have num_directions in its output shape,
            insert Reshape nodes to restore that dimension.
        """
        rnn_layer = match['rnn_layer']
        # num_directions is at position 1 in the output shape, and at position 0 in the hidden and cell states;
        # please refer to the docs of this transform
        direction_dim = [1, 0, 0]  # index of the direction dimension per output port
        for i in rnn_layer.out_nodes():
            old_data_node = rnn_layer.out_node(i)
            old_shape = old_data_node.shape.copy()
            new_shape = np.delete(old_shape, direction_dim[i])

            data = Op._create_data_node(graph, name=rnn_layer.name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
            graph.remove_edge(rnn_layer.id, old_data_node.id)
            graph.add_edge(rnn_layer.id, data.id, key=0, out=i)

            reshape = Reshape(graph, dict(dim=old_shape))
            reshape.create_node_with_data([data], dict(name=rnn_layer.name + '/SqueezeNumDirections/{}'.format(i)),
                                          data_nodes=[old_data_node])
    @staticmethod
    def squeeze_initial_states(graph: Graph, match: dict):
        """
        Squeeze the input initial states of the recurrent node to a 2-D shape.
        """
        hidden_init_port = 5
        cell_init_port = 6
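
        # The initial states come in with a leading num_directions dimension (assumed
        # [1, batch_size, hidden_size] at this point); the Reshape below flattens them
        # to the 2-D [batch_size, hidden_size] expected on ports 4 and 5 after normalization.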
        rnn_layer = match['rnn_layer']

        reshape = Reshape(graph, dict(dim=[rnn_layer.in_node(0).shape[rnn_layer.batch_dim], rnn_layer.hidden_size]))
        assert hidden_init_port in rnn_layer.in_nodes()
        init_h = rnn_layer.in_node(hidden_init_port)
        edge_attrs = deepcopy(graph.get_edge_data(init_h.id, rnn_layer.id)[0])
        edge_attrs['in'] = hidden_init_port
        graph.remove_edge(init_h.id, rnn_layer.id)
        new_init_h = reshape.create_node_with_data([init_h], dict(name=rnn_layer.name + '/HiddenStateResize'))
        graph.add_edge(new_init_h.id, rnn_layer.id, **edge_attrs)
        if rnn_layer.op == 'LSTM':
            assert cell_init_port in rnn_layer.in_nodes()

            init_c = rnn_layer.in_node(cell_init_port)
            edge_attrs = deepcopy(graph.get_edge_data(init_c.id, rnn_layer.id)[0])
            edge_attrs['in'] = cell_init_port
            graph.remove_edge(init_c.id, rnn_layer.id)
            new_init_c = reshape.create_node_with_data([init_c], dict(name=rnn_layer.name + '/CellStateResize'))
            graph.add_edge(new_init_c.id, rnn_layer.id, **edge_attrs)
    @staticmethod
    def reordering_inputs(graph: Graph, match: dict):
        """
        Reorder (renumber) the inputs to the format described in the class docstring:
        the initial state ports need to be renumbered.
        """
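        # Port remapping performed below, matching the class docstring:
        #   5 -> 4   initial hidden state
        #   6 -> 5   initial cell state (LSTM only)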
        rnn_layer = match['rnn_layer']
        assert 5 in rnn_layer.in_nodes()
        hidden_state_edge = graph.get_edge_data(rnn_layer.in_node(5).id, rnn_layer.id)
        hidden_state_edge[0]['in'] = 4

        if rnn_layer.op == 'LSTM':
            assert 6 in rnn_layer.in_nodes()
            cell_state_edge = graph.get_edge_data(rnn_layer.in_node(6).id, rnn_layer.id)
            cell_state_edge[0]['in'] = 5
    @staticmethod
    def ports_checks(graph: Graph, match: dict):
        """
        Check that all mandatory ports are present.
        """
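        # Per the class docstring: 0 - input data, 1 - WR weights, 2 - biases,
        # 4 - initial hidden state, plus 5 - initial cell state for LSTM.
        # Port 3 (sequence_length) remains optional.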
        rnn_layer = match['rnn_layer']
        mandatory_ports = [0, 1, 2, 4]

        if rnn_layer.op == 'LSTM':
            mandatory_ports.append(5)

        assert set(rnn_layer.in_nodes().keys()) >= set(mandatory_ports)