2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from extensions.middle.FusePermutesSequence import FusePermutesSequence
21 from extensions.middle.lstm_sequence_normalize import LSTMSequenceNormalize
22 from extensions.middle.mxnet_lstm_sequence_normalize import MXNetLSTMSequenceNormalize
23 from extensions.ops.lstm_cell import LSTMCell
24 from extensions.ops.tensor_iterator import TensorIterator
25 from mo.middle.replacement import MiddleReplacementPattern
26 from mo.ops.op import Op
27 from mo.ops.reshape import Reshape
30 class LSTMSequenceTensorIterator(MiddleReplacementPattern):
31 """ Converts normalized LSTMSequence op to TensorIterator.
33 Normalized LSTMSequence means that it should be processed by
34 LSTMSequenceNormalize transform that ensures its strict form.
36 This transformation builds an alternative sub-graph for LSTMSequence
37 with TensorIterator connected in the same way as an original LSTMSequence
38 node and with internal body represented as LSTMCell op node with necessary
39 squeezes and unsqueezes around.
45 return [LSTMSequenceNormalize, MXNetLSTMSequenceNormalize]
48 return [FusePermutesSequence]
53 ('lstm', dict(kind='op', op='LSTMSequence')),
54 ('input', dict(kind='data')),
55 ('weights', dict(kind='data')),
56 ('biases', dict(kind='data')),
57 # don't capture optional input initial states here
58 ('output', dict(kind='data')),
59 # don't capture optional output last states here
62 ('input', 'lstm', {'in': 0}),
63 ('weights', 'lstm', {'bin': 'weights', 'in': 1}),
64 ('biases', 'lstm', {'bin': 'biases', 'in': 2}),
65 ('lstm', 'output', {'out': 0}),
69 def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
# Replaces the matched LSTMSequence node with a TensorIterator op whose body graph
# is built around an LSTMCell op.
#   graph: outer graph; the TensorIterator is created in it and the LSTMSequence
#          node is removed from it.
#   match: pattern match; match['lstm'] is the LSTMSequence op node.
# NOTE(review): `lstm` is used throughout but its assignment is not among the visible
# lines; presumably it is match['lstm'] -- confirm.
72 # Build TensorIterator body first
73 body = nx.MultiDiGraph(name=lstm.name + '/sub_graph', layout=graph.graph['layout'])
# One body data node per LSTMSequence input, created in the order [0, 3, 4, 1, 2].
# The shape is always copied; the value is copied only when it is not None and the
# input index is 1 or 2, otherwise 'value' is None.
74 inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
75 {'shape': lstm.in_node(inp).shape.copy(),
76 'value': lstm.in_node(inp).value.copy()
77 if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
78 for inp in [0, 3, 4, 1, 2]]
# Body input 0 gets length 1 along sequence_dim. The Reshape target shape marks
# the batch dim as -1 and then drops the sequence dim.
79 inputs[0].shape[lstm.sequence_dim] = 1
80 reshape_dim = inputs[0].shape.copy()
81 reshape_dim[lstm.batch_dim] = -1
82 reshape_dim = np.delete(reshape_dim, lstm.sequence_dim)
83 input_squeeze = Reshape(
85 dict(name=lstm.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
# From here on inputs[0] is whatever the input_squeeze Reshape returns for the
# original body input 0; the connecting edge carries internal_port_id 0.
87 inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])
88 lstm_cell_op = LSTMCell(body, dict(hidden_size=match['lstm'].hidden_size, name=lstm.name + '/LSTMCell',
# Body output data nodes for ports 0 and 1: the shape is copied from the matching
# LSTMSequence output when that port exists, otherwise from input 3.
90 outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
91 {'shape': lstm.out_node(out).shape.copy() if out in lstm.out_nodes()
92 else lstm.in_node(3).shape.copy(), 'is_output': True}) for out in [0, 1]]
# Output 0's shape with sequence_dim set to 1 is the "unsqueezed" shape; deleting that
# axis gives the shape stored on outputs[0]. The batch dim is set to -1 only after
# the squeezed copy is taken, so only the Reshape target below carries the -1.
93 unsqueezed_output_shape = outputs[0].shape.copy()
94 unsqueezed_output_shape[lstm.sequence_dim] = 1
95 squeezed_output_shape = np.delete(unsqueezed_output_shape, lstm.sequence_dim)
96 outputs[0].shape = squeezed_output_shape
97 unsqueezed_output_shape[lstm.batch_dim] = -1
# NOTE(review): this name has no '/' separator, unlike '/input_squeeze' and
# '/LSTMCell' above -- confirm whether that is intentional.
98 output_unsqueeze = Reshape(body, dict(name=lstm.name + 'output_unsqueeze', dim=unsqueezed_output_shape,
100 # TODO edge attributes should be assigned by the op itself
101 lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
102 edge_attrs=[{}, {'internal_port_id': 1},
103 {'internal_port_id': 2}, {'bin': 'weights'},
# Internal port ids are written directly onto edges: out edges 0 and 1 of the node
# feeding lstm_cell_node[0] get ports 4 and 5.
105 lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
106 lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
# lstm_cell_node[0] is replaced by the result of output_unsqueeze; out edge 0 of its
# producer gets internal port 3, and the new node is flagged as an output.
107 lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0]])
108 lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
109 lstm_cell_node[0]['is_output'] = True
# Only the 'forward' and 'reverse' directions are accepted.
111 assert lstm.direction in ['forward', 'reverse']
112 if lstm.direction == 'forward':
117 assert lstm.direction == 'reverse'
123 'external_port_id': 3,
124 'internal_layer_id': 2,
125 'internal_port_id': 3,
126 'axis': lstm.sequence_dim,
# With three LSTMSequence outputs, two more output port map entries are added:
# external ports 4 and 5 mapped to internal ports 4 and 5 of internal layer 1.
133 if len(lstm.out_nodes()) == 3:
134 output_port_map.extend([{
135 'external_port_id': 4,
136 'internal_layer_id': 1,
137 'internal_port_id': 4,
139 'external_port_id': 5,
140 'internal_layer_id': 1,
141 'internal_port_id': 5,
# The TensorIterator op is created in the outer graph and named after the
# LSTMSequence node.
144 ti_op = TensorIterator(graph, {
145 'name': lstm.name + '/TensorIterator',
150 'external_port_id': 0,
151 'internal_layer_id': 0,
152 'internal_port_id': 0,
153 'axis': lstm.sequence_dim,
160 'external_port_id': 1,
161 'internal_layer_id': 1,
162 'internal_port_id': 1,
165 'external_port_id': 2,
166 'internal_layer_id': 1,
167 'internal_port_id': 2,
171 'output_port_map': output_port_map,
# Output data nodes are collected positionally below (range(len(...))), so the
# output port keys must be exactly 0..N-1.
189 assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
190 "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)
# LSTMSequence inputs 0, 3 and 4 are connected as external ports 0, 1 and 2; the
# existing LSTMSequence output data nodes are reused as the TensorIterator outputs.
191 outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 3, 4]],
192 data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
193 edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
194 {'external_port_id': 2}])
196 if not isinstance(outs, list):
# Remove the original LSTMSequence node, then tag the incoming edges of the outputs:
# external port 3 for the first output, 4 + i for each remaining one.
199 graph.remove_node(lstm.id)
200 outs[0].in_edge(0)['external_port_id'] = 3
201 for i, out in enumerate(outs[1:]):
202 external_port_id = 4 + i
203 out.in_edge()['external_port_id'] = external_port_id