2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
18 from mo.graph.graph import Node, Graph
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.concat import Concat
21 from mo.ops.op import Op
22 from mo.ops.split import Split
25 class DecomposeBidirectionalRNNSequence(MiddleReplacementPattern):
27 Decomposes bidirectional RNNSequence to forward and reverse RNNSequence ops.
29 Both initial states are split into two parts, and the two parts of the results are concatenated.
31 Axis of split/concat is completely defined by ONNX recurrent layers specification.
36 from extensions.middle.MXNetRNNSequenceNormalize import MXNetRNNSequenceNormalize
37 from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
38 return [ONNXRNNSequenceNormalize, MXNetRNNSequenceNormalize]
43 ('lstm', dict(kind='op', type='RNNSequence', direction='bidirectional')),
44 ('input', dict(kind='data')),
45 ('W', dict(kind='data')),
46 ('R', dict(kind='data')),
47 ('B', dict(kind='data')),
50 ('input', 'lstm', {'in': 0}),
51 ('W', 'lstm', {'in': 1}),
52 ('R', 'lstm', {'in': 2}),
53 ('B', 'lstm', {'in': 3}),
58 def split_helper(node: Node, index: int, direction: str, axis: int=0):
59 return Op._create_data_node(
61 name=node.name + '/SplittedBiLSTM/{}/'.format(direction),
62 attrs={'value': np.take(node.value, [index], axis),
63 'shape': np.array(np.take(node.value, [index], axis).shape, dtype=np.int64)}
66 def split_data(self, data: Node):
67 """ Helper. Split data node into two parts along axis 0 """
68 assert len(data.shape) == 3
69 assert data.shape[0] == 2
71 output_data = [Op._create_data_node(data.graph,
72 name=data.name + '/SplittedBiLSTM/{}'.format(['forward', 'reverse'][i])) for i in [0, 1]]
73 split_op = Split(data.graph, dict(name=data.name + '/DecomposedBiLSTM_0', axis=0, num_split=2,
75 return split_op.create_node_with_data([data], data_nodes=output_data)
77 def replace_pattern(self, graph: Graph, match: dict):
78 bidirectional_cell = match['lstm']
79 new_init_hiddens = self.split_data(bidirectional_cell.in_node(5))
80 new_init_cells = self.split_data(bidirectional_cell.in_node(6)) if 6 in bidirectional_cell.in_nodes()\
83 blob_bidirectional_split = lambda node: (
84 self.split_helper(node, 0, 'forward'),
85 self.split_helper(node, 1, 'reverse')
88 splitted_W = blob_bidirectional_split(bidirectional_cell.in_node(1))
89 splitted_R = blob_bidirectional_split(bidirectional_cell.in_node(2))
90 splitted_B = blob_bidirectional_split(bidirectional_cell.in_node(3))
92 outputs = self.split_bidirectional(
101 self.concat_outputs(bidirectional_cell, outputs[0], outputs[1], bidirectional_cell.out_nodes())
104 def get_new_cell(bidirectional_cell: Node, direction: str):
105 assert direction in ['forward', 'reverse']
107 cell_class = Op.get_op_class_by_name(bidirectional_cell.op)
108 new_cell = lambda graph, attrs: cell_class(graph, attrs)
109 attrs = bidirectional_cell.attrs().copy()
111 'direction': direction,
112 'name': bidirectional_cell.name + '/Split/' + direction,
114 attrs.update(new_attrs)
115 return new_cell(bidirectional_cell.graph, attrs)
117 def split_bidirectional(self,
118 bidirectional_cell: Node,
119 new_init_hiddens: list,
120 new_init_cells: list,
125 Split one bidirectional RNNSequence node into 2 one-directional RNNSequence nodes.
127 All input data nodes should be already prepared; they
128 have 2 in the num_dir dimension.
132 direction = ['forward', 'reverse'][i]
133 op = self.get_new_cell(bidirectional_cell, direction)
135 output_data = Op._create_data_node(
136 bidirectional_cell.graph,
137 name=bidirectional_cell.out_node(0).name + '/Split/' + str(i),
138 attrs={'shape': bidirectional_cell.out_node(0).shape.copy()}
141 assert output_data.shape[1] == 2
142 output_data.shape[1] = 1
144 output_hidden = Op._create_data_node(
145 bidirectional_cell.graph,
146 name=bidirectional_cell.out_node(1).name + '/Split/' + str(i),
147 attrs={'shape': bidirectional_cell.out_node(1).shape.copy()}
150 assert output_hidden.shape[0] == 2
151 output_hidden.shape[0] = 1
158 if bidirectional_cell.op == 'LSTM':
159 output_cell = Op._create_data_node(
160 bidirectional_cell.graph,
161 name=bidirectional_cell.out_node(2).name + '/Split/' + str(i),
162 attrs={'shape': bidirectional_cell.out_node(2).shape.copy()}
165 assert output_cell.shape[0] == 2
166 output_cell.shape[0] = 1
168 data_nodes.append(output_cell)
171 op.create_node_with_data(
173 bidirectional_cell.in_node(0),
179 new_init_cells[i] if bidirectional_cell.op == 'LSTM' else None,
181 data_nodes=data_nodes
187 def concat_outputs(bi_rnn, forward_outputs, reverse_outputs, final_outputs):
188 """ Concatenates two sets of outputs from bidirectional RNNSequence nodes """
190 Concat(bi_rnn.graph, {
191 'name': bi_rnn.name + '/FinalConcat/Data',
195 Concat(bi_rnn.graph, {
196 'name': bi_rnn.name + '/FinalConcat/HiddenState',
200 Concat(bi_rnn.graph, {
201 'name': bi_rnn.name + '/FinalConcat/CellState',
207 bi_rnn.graph.remove_node(bi_rnn.id)
209 for i in final_outputs:
210 concat_ops[i].create_node_with_data(
211 [forward_outputs[i], reverse_outputs[i]],
212 data_nodes=[final_outputs[i]]