Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / DecomposeBidirectionalRNNSequence.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from mo.graph.graph import Node, Graph
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.concat import Concat
21 from mo.ops.op import Op
22 from mo.ops.split import Split
23
24
class DecomposeBidirectionalRNNSequence(MiddleReplacementPattern):
    """
        Decomposes a bidirectional RNNSequence into forward and reverse RNNSequence ops.

        The initial states and the W/R/B blobs are each split into two parts (one per direction),
        and the corresponding outputs of the two new ops are concatenated back together.

        The split/concat axes follow the ONNX recurrent layers specification: the num_directions
        dimension is axis 0 for blobs and states, and axis 1 for the sequence output.
    """
    enabled = True

    def run_after(self):
        from extensions.middle.MXNetRNNSequenceNormalize import MXNetRNNSequenceNormalize
        from extensions.middle.ONNXRNNSequenceNormalize import ONNXRNNSequenceNormalize
        return [ONNXRNNSequenceNormalize, MXNetRNNSequenceNormalize]

    def pattern(self):
        return dict(
            nodes=[
                ('lstm', dict(kind='op', type='RNNSequence', direction='bidirectional')),
                ('input', dict(kind='data')),
                ('W', dict(kind='data')),
                ('R', dict(kind='data')),
                ('B', dict(kind='data')),
            ],
            edges=[
                ('input', 'lstm', {'in': 0}),
                ('W', 'lstm', {'in': 1}),
                ('R', 'lstm', {'in': 2}),
                ('B', 'lstm', {'in': 3}),
            ]
        )

    @staticmethod
    def split_helper(node: Node, index: int, direction: str, axis: int = 0):
        """
            Helper. Creates a data node whose value is the `index`-th slice of `node.value` along `axis`.

            The slice is taken with a list index, so the sliced dimension is kept (with size 1)
            and the rank of the value is unchanged.

            :param node: data node with a `value` to slice
            :param index: 0 for the forward part, 1 for the reverse part
            :param direction: 'forward' or 'reverse'; used only in the new node name
            :param axis: axis to slice along
            :return: the newly created data node
        """
        # Compute the slice once and derive the shape from it.
        value = np.take(node.value, [index], axis)
        return Op._create_data_node(
            node.graph,
            name=node.name + '/SplittedBiLSTM/{}/'.format(direction),
            attrs={'value': value,
                   'shape': np.array(value.shape, dtype=np.int64)}
        )

    def split_data(self, data: Node):
        """
            Helper. Splits a 3D data node into two parts along axis 0 with a Split op.

            :param data: data node of rank 3 whose axis 0 has size 2
            :return: the (forward, reverse) output data nodes, as returned by create_node_with_data
        """
        assert len(data.shape) == 3
        assert data.shape[0] == 2

        output_data = [Op._create_data_node(data.graph, name=data.name + '/SplittedBiLSTM/{}'.format(direction))
                       for direction in ['forward', 'reverse']]
        split_op = Split(data.graph, dict(name=data.name + '/DecomposedBiLSTM_0', axis=0, num_split=2,
                                          out_ports_count=2))
        return split_op.create_node_with_data([data], data_nodes=output_data)

    def replace_pattern(self, graph: Graph, match: dict):
        """
            Replaces the matched bidirectional cell with forward and reverse cells plus output Concats.

            Input ports read from the bidirectional cell:
            0 - input sequence, 1 - W, 2 - R, 3 - B, 5 - initial hidden state,
            6 - initial cell state (optional).
        """
        bidirectional_cell = match['lstm']

        # The initial hidden state (port 5) is always split; the initial cell state (port 6) only if connected.
        new_init_hiddens = self.split_data(bidirectional_cell.in_node(5))
        if 6 in bidirectional_cell.in_nodes():
            new_init_cells = self.split_data(bidirectional_cell.in_node(6))
        else:
            new_init_cells = (None, None)

        def split_blob(node: Node):
            """ Splits a W/R/B blob into its (forward, reverse) parts along axis 0 """
            return (self.split_helper(node, 0, 'forward'),
                    self.split_helper(node, 1, 'reverse'))

        splitted_W = split_blob(bidirectional_cell.in_node(1))
        splitted_R = split_blob(bidirectional_cell.in_node(2))
        splitted_B = split_blob(bidirectional_cell.in_node(3))

        forward_outputs, reverse_outputs = self.split_bidirectional(
            bidirectional_cell,
            new_init_hiddens,
            new_init_cells,
            splitted_W,
            splitted_R,
            splitted_B,
        )

        self.concat_outputs(bidirectional_cell, forward_outputs, reverse_outputs, bidirectional_cell.out_nodes())

    @staticmethod
    def get_new_cell(bidirectional_cell: Node, direction: str):
        """
            Returns an instance of the op class registered for `bidirectional_cell.op`.

            The instance is built from a copy of the bidirectional cell's attributes, with
            'direction' and 'name' overridden for the given direction.

            :param bidirectional_cell: the bidirectional RNNSequence node being decomposed
            :param direction: 'forward' or 'reverse'
        """
        assert direction in ['forward', 'reverse']

        cell_class = Op.get_op_class_by_name(bidirectional_cell.op)
        attrs = bidirectional_cell.attrs().copy()
        attrs.update({
            'direction': direction,
            'name': bidirectional_cell.name + '/Split/' + direction,
        })
        return cell_class(bidirectional_cell.graph, attrs)

    @staticmethod
    def _create_split_output(bidirectional_cell: Node, port: int, index: int, num_dir_axis: int):
        """
            Helper. Creates the data node for one direction's output on `port`.

            The new node copies the shape of the bidirectional cell's output on that port.
            The num_directions dimension (`num_dir_axis`, asserted to be 2) is then set to 1.
        """
        original_output = bidirectional_cell.out_node(port)
        output = Op._create_data_node(
            bidirectional_cell.graph,
            name=original_output.name + '/Split/' + str(index),
            attrs={'shape': original_output.shape.copy()}
        )

        assert output.shape[num_dir_axis] == 2
        output.shape[num_dir_axis] = 1
        return output

    def split_bidirectional(self,
                            bidirectional_cell: Node,
                            new_init_hiddens: list,
                            new_init_cells: list,
                            splitted_W: tuple,
                            splitted_R: tuple,
                            splitted_B: tuple):
        """
            Split one bidirectional RNNSequence node into 2 one-directional RNNSequence nodes.

            All input data nodes must already be prepared as (forward, reverse) pairs, i.e. split
            along the num_directions dimension.

            Returns a list [forward_outputs, reverse_outputs]. Each element holds the output data
            nodes of one new op, as returned by create_node_with_data: the sequence output, the
            hidden state and, for LSTM, the cell state.
        """
        all_outputs = []
        for i, direction in enumerate(['forward', 'reverse']):
            op = self.get_new_cell(bidirectional_cell, direction)

            # The sequence output carries num_directions on axis 1; the states carry it on axis 0.
            data_nodes = [
                self._create_split_output(bidirectional_cell, 0, i, num_dir_axis=1),
                self._create_split_output(bidirectional_cell, 1, i, num_dir_axis=0),
            ]
            if bidirectional_cell.op == 'LSTM':
                data_nodes.append(self._create_split_output(bidirectional_cell, 2, i, num_dir_axis=0))

            all_outputs.append(
                op.create_node_with_data(
                    inputs=[
                        bidirectional_cell.in_node(0),
                        splitted_W[i],
                        splitted_R[i],
                        splitted_B[i],
                        None,  # port 4 is left unconnected
                        new_init_hiddens[i],
                        new_init_cells[i] if bidirectional_cell.op == 'LSTM' else None,
                    ],
                    data_nodes=data_nodes
                )
            )
        return all_outputs

    @staticmethod
    def concat_outputs(bi_rnn, forward_outputs, reverse_outputs, final_outputs):
        """
            Concatenates two sets of outputs from the one-directional RNNSequence nodes into the original outputs.

            :param bi_rnn: the bidirectional node; it is removed from the graph
            :param forward_outputs: output data nodes of the forward op, indexed by output port
            :param reverse_outputs: output data nodes of the reverse op, indexed by output port
            :param final_outputs: the original output data nodes of `bi_rnn`, keyed by output port;
                each becomes the output of a Concat. Port 0 (sequence output) is concatenated along
                axis 1, ports 1 and 2 (hidden and cell state) along axis 0.
        """
        concat_ops = [
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/Data',
                'axis': 1,
                'in_ports_count': 2,
            }),
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/HiddenState',
                'axis': 0,
                'in_ports_count': 2,
            }),
            Concat(bi_rnn.graph, {
                'name': bi_rnn.name + '/FinalConcat/CellState',
                'axis': 0,
                'in_ports_count': 2,
            })
        ]

        # Remove the bidirectional node first; its former output data nodes are then reused as Concat outputs.
        bi_rnn.graph.remove_node(bi_rnn.id)

        for port in final_outputs:
            concat_ops[port].create_node_with_data(
                [forward_outputs[port], reverse_outputs[port]],
                data_nodes=[final_outputs[port]]
            )