Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / MXNetSplitMultiLayers.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from mo.graph.graph import Graph, Node
19 from mo.middle.replacement import MiddleReplacementPattern
20 from mo.ops.concat import Concat
21 from mo.ops.op import Op
22
23
class MXNetSplitLayersToRNNSequence(MiddleReplacementPattern):
    """
        Split a multilayer MXNet RNNSequence cell (LSTM/GRU/RNN) into a chain of
        single-layer cells of the same op type, and concatenate the per-layer output
        hidden (and, for LSTM, cell) states into the original state outputs.
    """
    enabled = True

    def pattern(self):
        # A multilayer MXNet RNNSequence op with its data input on port 0 and its
        # packed parameters blob on port 1.
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence', format='mxnet', multilayers=True)),
                ('input', dict(kind='data')),
                ('params', dict(kind='data')),
            ],
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('params', 'rnn_layer', {'in': 1}),
            ]
        )

    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched multilayer cell with single-layer cells, then remove it."""
        output_states = self.split_multilayer_cell(graph, match)
        self.concat_output_states(graph, match, output_states)

        rnn_layer = match['rnn_layer']
        rnn_layer.graph.remove_node(rnn_layer.id)

    @staticmethod
    def get_new_cell(multilayer_cell: Node, number: int):
        """
        Create a single-layer op of the same class as ``multilayer_cell``.

        All attributes are copied from the multilayer cell, except that
        ``num_layers`` is set to 1, ``multilayers`` to False, and the name gets a
        per-layer suffix.

        :param multilayer_cell: the multilayer RNNSequence node being split.
        :param number: index of the layer this new cell represents.
        :return: an (unplaced) Op instance; the caller creates its graph node.
        """
        cell_class = Op.get_op_class_by_name(multilayer_cell.op)
        attrs = multilayer_cell.attrs().copy()
        attrs.update({
            'num_layers': 1,
            'multilayers': False,
            # The 'LayerSplittedLSTM' suffix is used for every cell type (GRU/RNN too);
            # it is kept unchanged so generated node names stay the same.
            'name': multilayer_cell.name + '/LayerSplittedLSTM/{}'.format(number),
        })
        return cell_class(multilayer_cell.graph, attrs)

    @staticmethod
    def _create_layer_state_node(rnn_layer: Node, state_value, layer: int, direction: int, suffix: str):
        """
        Create a constant data node with the initial state rows of one layer.

        Rows ``[layer * direction, (layer + 1) * direction)`` of ``state_value`` are
        taken along axis 0.
        """
        # The state is sliced by value, so it must be a constant at this point.
        assert state_value is not None, \
            'Initial state of {} must have a constant value to be split per layer'.format(rnn_layer.name)
        layer_state = state_value[layer * direction: (layer + 1) * direction]
        return Op._create_data_node(
            rnn_layer.graph,
            name=str(rnn_layer.name) + '/{}/{}/'.format(suffix, layer),
            attrs={'value': layer_state, 'shape': np.array(layer_state.shape, dtype=np.int64)}
        )

    def split_multilayer_cell(self, graph: Graph, match: dict):
        """
        Split one multilayer type=RNNSequence cell into num_layers consecutive single-layer cells.

        How the inputs are divided:
          - Parameters: the packed blob is treated as all layers' weights followed by
            all layers' biases. Each layer owns 2 * hidden_size * directions * multiplier
            bias elements.
          - Initial states: the optional hidden (port 2) and cell (port 3) states are
            sliced along axis 0, ``directions`` rows per layer.

        How the new cells are wired:
          - Each new cell consumes the sequence output of the previous one.
          - The last cell writes into the original sequence output (port 0).

        :param graph: graph being transformed (nodes are created in the matched
            cell's own graph).
        :param match: pattern match with 'rnn_layer', 'input' and 'params' nodes.
        :return: [hidden_state_nodes, cell_state_nodes], the output state data nodes
            of the new cells in layer order. The second list stays empty unless the
            cell op is 'LSTM'.
        """
        input_data = match['input']
        rnn_layer = match['rnn_layer']
        params = match['params'].value.copy()

        # Optional initial states: port 2 = hidden state, port 3 = cell state.
        in_nodes = rnn_layer.in_nodes()
        have_hidden = 2 in in_nodes
        hidden_state_value = rnn_layer.in_node(2).value if have_hidden else None
        have_cell = 3 in in_nodes
        cell_state_value = rnn_layer.in_node(3).value if have_cell else None

        direction = 2 if rnn_layer.has_num_directions else 1
        num_layers = rnn_layer.num_layers
        hidden_size = rnn_layer.hidden_size
        input_size = input_data.shape[2]
        is_lstm = rnn_layer.op == 'LSTM'

        # Rows of one layer's weight matrices over all directions
        # (`multiplier` presumably counts the gates of the cell type).
        size = hidden_size * direction * rnn_layer.multiplier
        # Per layer: input weights + recurrent weights + two bias vectors of `size`.
        first_layer_params_size = (input_size + hidden_size + 2) * size
        # Later layers take the previous layer's output (hidden_size * direction features).
        other_layer_params_size = (hidden_size * direction + hidden_size + 2) * size
        expected_params_size = first_layer_params_size + (num_layers - 1) * other_layer_params_size
        assert params.size == expected_params_size, \
            'Unexpected parameters size for {}: got {}, expected {}'.format(
                rnn_layer.name, params.size, expected_params_size)

        # All biases (layer_bsize per layer) are stored after all weights.
        layer_bsize = 2 * size
        bsize = layer_bsize * num_layers
        param_w = params[:len(params) - bsize]
        param_b = params[len(params) - bsize:]

        # The original output names are reused as prefixes for per-layer data nodes.
        # Ports 1/2 may be absent (concat_output_states below skips them as well), so
        # fall back to a name derived from the layer instead of failing on them.
        out_nodes = rnn_layer.out_nodes()
        final_output = rnn_layer.out_node(0)
        hidden_name_prefix = rnn_layer.out_node(1).name if 1 in out_nodes else rnn_layer.name + '/HiddenState'
        cell_name_prefix = rnn_layer.out_node(2).name if 2 in out_nodes else rnn_layer.name + '/CellState'

        # Per-layer output state shape: [batch, hidden] or [directions, batch, hidden].
        state_size = np.array([input_data.shape[rnn_layer.batch_dim], hidden_size], dtype=np.int64)
        if rnn_layer.has_num_directions:
            state_size = np.insert(state_size, 0, direction)

        input_node = input_data
        w_offset = 0
        output_states = [[], []]

        for layer in range(num_layers):
            params_layer_size = first_layer_params_size if layer == 0 else other_layer_params_size
            layer_w_size = params_layer_size - layer_bsize
            layer_params = np.concatenate((
                param_w[w_offset: w_offset + layer_w_size],
                param_b[layer_bsize * layer: layer_bsize * (layer + 1)],
            ), axis=0)
            w_offset += layer_w_size

            op = self.get_new_cell(rnn_layer, layer)

            params_value_node = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.name + '/LayerSplittedParamsLSTM/{}/'.format(layer),
                attrs={'value': layer_params, 'shape': np.array(layer_params.shape, dtype=np.int64)}
            )
            hidden_state_value_node = self._create_layer_state_node(
                rnn_layer, hidden_state_value, layer, direction, 'LayerSplittedHiddenState') \
                if have_hidden else None
            cell_state_value_node = self._create_layer_state_node(
                rnn_layer, cell_state_value, layer, direction, 'LayerSplittedCellState') \
                if have_cell else None

            if layer < num_layers - 1:
                output_data = Op._create_data_node(
                    rnn_layer.graph,
                    name=final_output.name + '/LayerSplit/' + str(layer),
                    attrs={'shape': final_output.shape.copy()}
                )
            else:
                # The last layer writes directly into the original sequence output.
                output_data = final_output

            current_data_nodes = [
                output_data,
                Op._create_data_node(
                    rnn_layer.graph,
                    name=hidden_name_prefix + '/LayerSplit/' + str(layer),
                    attrs={'shape': np.array(state_size)}
                ),
            ]
            if is_lstm:
                current_data_nodes.append(Op._create_data_node(
                    rnn_layer.graph,
                    name=cell_name_prefix + '/LayerSplit/' + str(layer),
                    attrs={'shape': np.array(state_size)}
                ))

            data_nodes = op.create_node_with_data(
                inputs=[
                    input_node,
                    params_value_node,
                    hidden_state_value_node,
                    cell_state_value_node
                ],
                data_nodes=current_data_nodes,
            )

            # This layer's sequence output feeds the next layer.
            input_node = data_nodes[0]
            output_states[0].append(data_nodes[1])
            if is_lstm:
                output_states[1].append(data_nodes[2])

        return output_states

    @staticmethod
    def concat_output_states(graph: Graph, match: dict, new_states: list):
        """
        Concatenate per-layer output states into the original multilayer state outputs.

        An original state output that is absent (out port 1 or 2 missing) is skipped.

        :param graph: unused; nodes are created in the matched cell's own graph.
        :param match: pattern match with the 'rnn_layer' node.
        :param new_states: [hidden_state_nodes, cell_state_nodes], as returned by
            split_multilayer_cell.
        """
        rnn_layer = match['rnn_layer']
        # Original hidden (port 1) and cell (port 2) state outputs; None when absent.
        original_states = [rnn_layer.out_node(port) if port in rnn_layer.out_nodes() else None for port in (1, 2)]
        state_names = ('HiddenState', 'CellState')

        for original_state, layer_states, state_name in zip(original_states, new_states, state_names):
            if original_state is None:
                continue
            # Per-layer states are joined along their last axis.
            concat = Concat(rnn_layer.graph, {
                'name': rnn_layer.name + '/FinalLayerSplitConcat/' + state_name,
                'axis': -1
            })
            concat.attrs.update({'in_ports_count': len(layer_states)})
            concat.create_node_with_data(inputs=layer_states, data_nodes=[original_state])