"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
from copy import deepcopy

import networkx as nx
import numpy as np

from extensions.middle.decompose_bi_lstm import DecomposeBiLSTM
from mo.graph.graph import Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.permute import Permute
from mo.ops.reshape import Reshape
from mo.utils.error import Error
def inverse_perm(order: np.array):
    """ Return the inverse of the permutation `order`.

        For any 1D array `a` of the same size: a[order][inverse_perm(order)] == a.
        The visible original computed the inverse but never returned it
        (falling through to an implicit None) -- the `return` is the fix.
    """
    indices = np.empty(order.size, dtype=np.int64)
    indices[order] = np.arange(order.size)
    return indices
def permute_before_and_after(inp: Node, middle: Node, out: Node, order):
    ''' Insert two permutes: before middle node and after middle node.

        The first permute has a given order, the second permute has an
        inversed order.
    '''
    # The docstring above was left unterminated in the original text; it is
    # closed here so the module parses.

    # Input side: inp -> Permute(order) -> middle, preserving original edge attributes.
    permute = Permute(middle.graph, dict(order=np.array(order)))

    edge_attrs = deepcopy(middle.graph.get_edge_data(inp.id, middle.id)[0])
    middle.graph.remove_edge(inp.id, middle.id)
    new_inp = permute.create_node_with_data([inp], dict(name=middle.name + '/InputPermute'))
    middle.graph.add_edge(new_inp.id, middle.id, **edge_attrs)

    # Output side: middle -> new data node (permuted shape) -> Permute(inverse) -> out.
    permute = Permute(middle.graph, dict(order=inverse_perm(np.array(order))))

    middle.graph.remove_edge(middle.id, out.id)
    # The intermediate data node carries the shape as seen before the inverse permute.
    new_out = Op._create_data_node(middle.graph, name=middle.name + '/WithoutPermute', attrs={'shape': out.shape[order]})
    middle.graph.add_edge(middle.id, new_out.id, key=0, out=0)
    permute.create_node_with_data([new_out], dict(name=middle.name + '/OutputPermute'), data_nodes=out)
class LSTMSequenceNormalize(MiddleReplacementPattern):
    ''' Convert blobs and shapes of ONNX-like LSTM to IE compatible form.

        Fuse W, R and optional B input blobs to weights and biases according
        to IE LSTM specification. In case of bidirectional LSTM, the resulting
        blobs are not directly supported by IE, but it will be further processed
        by a separate transformation to break down to one-directional LSTMs.

        The target form of this operation is not normally covered by a dedicated
        layer in IE. It should be further transformed to some other layer
        that are supported by IE. This transformation pass involves weights and
        shapes processing only.

        Post-conditions:

        Inputs have the following order:
            0: input data
            1: weights blob
            2: biases blob
            3: initial hidden state [optional]
            4: initial cell state [optional]
    '''

    enabled = True

    def run_after(self):
        # Bidirectional LSTMs are not handled here (see docstring above); they
        # must be decomposed into one-directional LSTMs first.
        return [DecomposeBiLSTM]

    def pattern(self):
        return dict(
            nodes=[
                ('lstm', dict(kind='op', op='LSTMSequence', format='onnx')),
                ('input', dict(kind='data')),
                ('W', dict(kind='data')),
                ('R', dict(kind='data')),
            ],
            edges=[
                ('input', 'lstm', {'in': 0}),
                ('W', 'lstm', {'bin': 'W'}),
                ('R', 'lstm', {'bin': 'R'}),
            ]
        )

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        """ Apply the normalization steps in order to the matched LSTM node. """
        self.repack_weights(graph, match)
        if match['lstm'].has_num_directions:
            self.squeeze_num_directions(graph, match)
        self.batch_sequence_transpose(graph, match)
        self.check_not_supported_ports(graph, match)
        self.states_squeeze(graph, match)

    def repack_weights(self, graph: nx.MultiDiGraph, match: dict):
        """ Fuse W, R and optional B blobs into single IE-style weights and
            biases blobs, reordering gates iofc -> fico, and attach them to
            the LSTM node as input ports 1 ('weights') and 2 ('biases').
        """
        lstm = match['lstm']
        W = match['W'].value.copy()
        R = match['R'].value.copy()

        # bidirectional case should be processed separately before this transformation
        if lstm.direction not in ['forward', 'reverse']:
            raise Error('ONNX/LSTM operator with `forward` or `reverse` is supported only. '
                        'Node {} has direction = {} which is not supported.'.format(lstm.name, lstm.direction))

        graph.remove_edge(match['W'].id, lstm.id)
        graph.remove_edge(match['R'].id, lstm.id)

        # Optional B input at port 3; when absent, substitute zero biases
        # (4 gates x 2 components (W and R) = 8 * hidden_size values).
        if 3 in lstm.in_nodes():
            # TODO: check if 'bin': 'B' attribute is assigned to this edge
            B = lstm.in_node(3).value.copy()
            graph.remove_edge(lstm.in_node(3).id, lstm.id)
        else:
            B = np.full([1, lstm.hidden_size * 8], 0, dtype=np.float32)

        # Add extra dimensions for W, R and B for easier repacking

        B = B.reshape([
            1,  # 0: num of directions, limitation: should be 1
            2,  # 1: two input parts of the matrix: W, R
            4,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
            lstm.hidden_size,  # 3: output size per direction and gate
            1,  # 4: fake dimension to match the input dimension in W and R for shorter code
        ])

        W, R = [x.reshape([
                1,  # 0: num of directions, limitation: should be 1
                1,  # 1: dummy dimension to be aligned with B
                4,  # 2: four output parts of the matrix for all gates in order: i, o, f, c
                lstm.hidden_size,  # 3: output size per direction and gate
                -1])  # 4: input size/hidden size in W/R
                for x in (W, R)]

        input_size = match['input'].shape[2]
        assert input_size == W.shape[-1]

        WR = np.concatenate([W, R], axis=4)

        # Reorder gates: iofc --> fico
        gate_reorder = [2, 0, 3, 1]
        WR = np.take(WR, gate_reorder, axis=2)
        B = np.take(B, gate_reorder, axis=2)

        # Sum component of B that correspond to W and R
        B = np.add.reduce(B, axis=1, keepdims=True)

        # Reorder dimensions by collection output dimensions first, then input dimension
        # Interpret the numbers below by looking at W, R and B reshape above in the code
        inout_reorder = [0, 2, 3, 1, 4]
        WR = WR.transpose(inout_reorder)
        B = B.transpose(inout_reorder)

        # Supposing it is unidirectional LSTM, squeeze 'direction' dimension
        assert WR.shape[0] == 1
        assert B.shape[0] == 1
        WR = WR.squeeze(axis=0)
        B = B.squeeze(axis=0)

        # Flatten all output (0, 1) and input dimensions (2, 3)
        final_shape = [WR.shape[0] * WR.shape[1], -1]
        WR = WR.reshape(final_shape)
        B = B.reshape(final_shape)

        # Squeeze fake dimension in B
        B = B.squeeze(axis=-1)

        assert WR.shape[0] == lstm.hidden_size * 4
        assert B.shape[0] == lstm.hidden_size * 4
        assert WR.shape[1] == lstm.hidden_size + input_size

        for blob, port, name in [(WR, 1, 'weights'), (B, 2, 'biases')]:
            Op.create_and_connect_input_data_node(
                graph,
                lstm,
                {'value': blob, 'shape': np.array(blob.shape, dtype=np.int64)},
                {'in': port, 'bin': name, 'permutation': None}
            )

    def squeeze_num_directions(self, graph: nx.MultiDiGraph, match: dict):
        """ Assuming considered LSTM node has num_directions in output shape, remove it. """
        lstm = match['lstm']
        # num_directions is at 1st position in output shape, please refer to LSTMSequence op definition
        direction_dim = [1, 0, 0]  # index of dimension with direction index
        for i in lstm.out_nodes():
            old_data_node = lstm.out_node(i)
            old_shape = old_data_node.shape.copy()
            new_shape = np.delete(old_shape, direction_dim[i])
            data = Op._create_data_node(graph, name=lstm.name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
            graph.remove_edge(lstm.id, old_data_node.id)
            graph.add_edge(lstm.id, data.id, key=0, out=i)
            # Reshape back to the original shape so downstream consumers keep
            # seeing the shape they were inferred with.
            reshape = Reshape(graph, dict(dim=old_shape))
            reshape.create_node_with_data([data], dict(name=lstm.name + '/SqueezeNumDirections/{}'.format(i)), data_nodes=[old_data_node])

    def batch_sequence_transpose(self, graph: nx.MultiDiGraph, match: dict):
        """ Normalize dimension order of the LSTM data input/output by
            inserting Permute ops when batch and sequence dims are swapped.
        """
        lstm = match['lstm']
        inp = lstm.in_node(0)
        out = lstm.out_node(0)

        if lstm.batch_dim == 0:
            assert lstm.sequence_dim == 1
            # nothing to do -- it's already in normal form
            return

        assert lstm.sequence_dim == 0
        assert lstm.batch_dim == 1
        assert len(inp.shape) == 3

        # Reorder the first two dimensions on both ends: input and output.
        # Two Permute ops are inserted before and after the LSTM node.
        # In this transformation we don't analyze the rest of the model around
        # LSTM cell, so these Permute ops are not fused to some other layers here.
        # But other transformations in the pipeline may optimize the Permute ops out.

        lstm.batch_dim, lstm.sequence_dim = lstm.sequence_dim, lstm.batch_dim
        permute_before_and_after(inp, lstm, out, [1, 0, 2])

    def check_not_supported_ports(self, graph: nx.MultiDiGraph, match: dict):
        """ Verify only supported input ports remain connected after repacking:
            0 (data), 1 (weights), 2 (biases), and optionally 5/6 (init states).
        """
        lstm = match['lstm']
        inputs = lstm.in_edges()
        assert 0 in inputs
        assert 1 in inputs and inputs[1]['bin'] == 'weights'
        assert 2 in inputs and inputs[2]['bin'] == 'biases'
        assert 3 not in inputs

        if not (set(inputs.keys()) <= {0, 1, 2, 5, 6}):
            # Typo 'sequence_lenght' in the original message fixed here.
            raise Error('Node {} that is interpreted as {} operation has '
                        'some unexpected inputs initialized, '
                        'they can include: sequence_length, '
                        'and weight tensor for peepholes. '
                        'This is not supported.'.format(lstm.name, lstm.op))

    def states_squeeze(self, graph: nx.MultiDiGraph, match: dict):
        """ Reshape optional initial hidden/cell states (input ports 5 and 6)
            to 2D and remap them to ports 3 and 4 per the class post-conditions.
        """
        lstm = match['lstm']

        # assumes in_node(0).shape[0] is the batch dimension after
        # batch_sequence_transpose has run -- TODO confirm
        reshape = Reshape(graph, dict(dim=[lstm.in_node(0).shape[0], lstm.hidden_size]))

        if len(lstm.in_nodes()) > 3:
            init_h = lstm.in_node(5)
            edge_attrs = deepcopy(graph.get_edge_data(init_h.id, lstm.id)[0])
            edge_attrs['in'] = 3
            graph.remove_edge(init_h.id, lstm.id)
            new_init_h = reshape.create_node_with_data([init_h], dict(name=lstm.name + '/HiddenStateResize'))
            graph.add_edge(new_init_h.id, lstm.id, **edge_attrs)

        if len(lstm.in_nodes()) > 4:
            init_c = lstm.in_node(6)
            edge_attrs = deepcopy(graph.get_edge_data(init_c.id, lstm.id)[0])
            edge_attrs['in'] = 4
            graph.remove_edge(init_c.id, lstm.id)
            new_init_c = reshape.create_node_with_data([init_c], dict(name=lstm.name + '/CellStateResize'))
            graph.add_edge(new_init_c.id, lstm.id, **edge_attrs)