2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from extensions.middle.FusePermutesSequence import FusePermutesSequence
19 from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
20 from extensions.ops.lstm_cell import LSTMCell
21 from extensions.ops.tensor_iterator import TensorIterator
22 from mo.graph.graph import Graph, add_opoutput
23 from mo.middle.replacement import MiddleReplacementPattern
24 from mo.ops.op import Op
25 from mo.ops.reshape import Reshape
# NOTE(review): this text looks like a line-numbered, de-indented extract of the module.
# The embedded numbers skip values (e.g. 42 -> 45 -> 48 -> 53), so some statements
# (presumably `def` headers, short assignments and closing brackets) are not visible here.
# Comments added below describe only what the visible lines show; confirm against the real file.
28 class LSTMToTensorIterator(MiddleReplacementPattern):
29 """ Converts normalized RNNSequence with op=LSTM to TensorIterator.
31 Normalized RNNSequence means that it should be processed by
32 RNNSequenceNormalize transform that ensures its strict form.
34 This transformation builds an alternative sub-graph for LSTMSequence
35 with TensorIterator connected in the same way as an original LSTMSequence
36 node and with internal body represented as LSTMCell op node with necessary
37 squeezes and unsqueezes around.
# Class-level identifier string of this transformation.
42 id = 'lstm_to_tensor_iterator'
# Presumably the body of an ordering hook (its `def` line is not visible): names
# RNNSequenceNormalize, the transform the class docstring says must have processed the node.
45 return [RNNSequenceNormalize]
# Presumably the body of the complementary ordering hook (its `def` line is not visible):
# names FusePermutesSequence.
48 return [FusePermutesSequence]
# Pattern nodes: the op node with op='LSTM' / type='RNNSequence' plus four data nodes
# (input, weights, biases, output).
53 ('lstm', dict(kind='op', op='LSTM', type='RNNSequence')),
54 ('input', dict(kind='data')),
55 ('weights', dict(kind='data')),
56 ('biases', dict(kind='data')),
57 # don't capture optional input initial states here
58 ('output', dict(kind='data')),
59 # don't capture optional output last states here
# Pattern edges: input -> port 0, weights -> port 1, biases -> port 2 of the LSTM op;
# the op's output port 0 -> the 'output' data node.
62 ('input', 'lstm', {'in': 0}),
63 ('weights', 'lstm', {'bin': 'weights', 'in': 1}),
64 ('biases', 'lstm', {'bin': 'biases', 'in': 2}),
65 ('lstm', 'output', {'out': 0}),
# Replaces one matched LSTM sequence node: builds a separate `body` Graph containing an
# input Reshape, an LSTMCell and an output Reshape, creates a TensorIterator op in `graph`
# wired to the LSTM's existing input/output data nodes, and removes the LSTM node.
#   graph: the graph being transformed.
#   match: the pattern match dictionary.
# NOTE(review): `lstm` is used below but its assignment is not visible in this extract;
# presumably it is the matched 'lstm' node taken from `match` -- confirm.
69 def replace_pattern(self, graph: Graph, match: dict):
72 # Build TensorIterator body first
73 body = Graph(name=lstm.name + '/sub_graph')
# The body shares the outer graph's graph-level attribute dictionary (same object, not a copy).
74 body.graph = graph.graph
76 # 1. Input squeeze Reshape
# One data node in the body per LSTM input port listed below. Shapes are copied for all of
# them; values are copied only for ports 1 and 2 (weights, biases) when a value is present.
77 inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
78 {'shape': lstm.in_node(inp).shape.copy(),
79 'value': lstm.in_node(inp).value.copy()
80 if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
81 for inp in [0, 4, 5, 1, 2]] # X, h_init, c_init, WR, B
# The body input X has size 1 along the sequence dimension. The Reshape target `dim` is that
# shape with the batch entry set to -1 and the sequence axis deleted.
# NOTE(review): `np` is used here but no numpy import is visible in this extract -- confirm.
83 inputs[0].shape[lstm.sequence_dim] = 1
84 reshape_dim = inputs[0].shape.copy()
85 reshape_dim[lstm.batch_dim] = -1
86 reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
# NOTE(review): the first positional argument of this Reshape(...) call is not visible here;
# the other Reshape below is constructed with `body` first.
87 input_squeeze = Reshape(
89 dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
# Replace inputs[0] by the result of the squeeze Reshape; its input edge is internal port 0.
91 inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
93 # 2. Output unsqueeze Reshape
# Body data nodes for LSTM output ports 0 and 1. If the LSTM has no node on such a port,
# the shape of LSTM input 4 is used instead.
94 outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
95 {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
96 else lstm.in_node(4).shape.copy()}) for out in [0, 1]]
# NOTE(review): `out` is presumably bound by a `for` header that is missing from this extract.
98 add_opoutput(body, out.id, 0, False)
# Two shapes derived from output 0: the "unsqueezed" one has the sequence dimension set to 1
# (and later batch set to -1, used as the Reshape `dim`); the "squeezed" one has the sequence
# axis deleted and becomes the shape of the body's outputs[0] data node.
100 unsqueezed_output_shape = outputs[0].shape.copy()
101 unsqueezed_output_shape[lstm.sequence_dim] = 1
102 squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
103 outputs[0].shape = squeezed_output_shape
104 unsqueezed_output_shape[lstm.batch_dim] = -1
# NOTE(review): this name is built as lstm.name + 'output_unsqueeze' without the '/' separator
# used by '/input_squeeze' and '/LSTMCell' -- confirm whether that is intentional.
105 output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', dim=unsqueezed_output_shape,
106 internal_layer_id=2))
# LSTMCell op in the body (internal_layer_id=1); the attributes listed here are copied from
# the LSTM sequence node.
109 lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
110 activations=lstm.activations,
111 activation_alpha=lstm.activation_alpha,
112 activation_beta=lstm.activation_beta,
114 input_forget=lstm.input_forget,
115 name=lstm.name + '/LSTMCell',
116 internal_layer_id=1))
# Connect the five body inputs to the cell and reuse `outputs` as its output data nodes.
# edge_attrs follow the order of `inputs`: squeezed X (no attrs), then internal ports 1 and 2,
# then the 'weights' blob edge (the remaining entry is not visible in this extract).
117 lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
118 edge_attrs=[{}, {'internal_port_id': 1},
119 {'internal_port_id': 2}, {'bin': 'weights'},
# lstm_cell_node[0].in_node() is presumably the LSTMCell op (producer of the first returned
# data node); its out edges 0 and 1 are tagged as internal ports 4 and 5.
121 lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
122 lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
# Route the first cell output through the unsqueeze Reshape, tag that Reshape's out edge as
# internal port 3 and register its result as a body output.
123 lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
124 lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
125 add_opoutput(body, lstm_cell_node[0].id, 0, False)
127 # 4. TensorIterator layer creating
# Only 'forward' and 'reverse' directions are accepted.
# NOTE(review): the statements inside the two direction branches are not visible in this
# extract; also note these checks are `assert`s, which Python skips under -O.
128 assert lstm.direction in ['forward', 'reverse']
129 if lstm.direction == 'forward':
134 assert lstm.direction == 'reverse'
# Output port map entry: external port 3 <-> internal layer 2 (the output_unsqueeze Reshape),
# internal port 3, with 'axis' set to the sequence dimension.
140 'external_port_id': 3,
141 'internal_layer_id': 2,
142 'internal_port_id': 3,
144 'axis': lstm.sequence_dim,
151 # Adding h_state, c_state to outputs
# Only when the LSTM node has exactly three outputs: expose internal layer 1 (the LSTMCell)
# ports 4 and 5 as external ports 4 and 5.
152 if len(lstm.out_nodes()) == 3:
153 output_port_map.extend([{
154 'external_port_id': 4,
155 'internal_layer_id': 1,
156 'internal_port_id': 4,
158 'external_port_id': 5,
159 'internal_layer_id': 1,
160 'internal_port_id': 5,
# The TensorIterator op itself is created in the outer `graph` (not in `body`); its number of
# output ports equals the number of outputs of the original LSTM node.
163 ti_op = TensorIterator(graph, {
164 'name': lstm.name + '/TensorIterator',
167 'out_ports_count': len(lstm.out_nodes()),
# Port map entry for external port 0: goes to internal layer 0 (the input_squeeze Reshape),
# port 0, with 'axis' set to the sequence dimension.
171 'external_port_id': 0,
172 'internal_layer_id': 0,
173 'internal_port_id': 0,
175 'axis': lstm.sequence_dim,
# Port map entries for external ports 1 and 2: go to internal layer 1 (the LSTMCell),
# ports 1 and 2; no 'axis' is visible for these entries.
182 'external_port_id': 1,
183 'internal_layer_id': 1,
184 'internal_port_id': 1,
187 'external_port_id': 2,
188 'internal_layer_id': 1,
189 'internal_port_id': 2,
193 'output_port_map': output_port_map,
# The LSTM's output port indices must be exactly 0..N-1, because the outputs are addressed by
# range(len(...)) right below.
211 assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
212 "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
# Instantiate the TensorIterator on the LSTM's own input data nodes 0, 4, 5 and reuse the
# LSTM's existing output data nodes, so the surrounding graph stays connected the same way.
214 outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]], # X, h_init, c_init
215 data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
216 edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
217 {'external_port_id': 2}])
# NOTE(review): the body of this `if` is not visible here; presumably it wraps a single
# returned node into a list so the indexing below works -- confirm.
219 if not isinstance(outs, list):
# Drop the original LSTM op, then tag the incoming edge of each output data node:
# external port 3 for the first output and 4, 5, ... for any further outputs.
222 graph.remove_node(lstm.id)
223 outs[0].in_edge(0)['external_port_id'] = 3
224 for i, out in enumerate(outs[1:]):
225 external_port_id = 4 + i
226 out.in_edge()['external_port_id'] = external_port_id