2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
import numpy as np

from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.concat import Concat
from mo.ops.op import Op
24 class MXNetSplitLayersToRNNSequence(MiddleReplacementPattern):
    Split an MXNet multilayer cell into multiple single-layer LSTM/GRU/RNN cells.
    Also concatenate the output hidden and cell states of these layers.
def pattern(self):
    """
    Describe the sub-graph handled by this replacer.

    The pattern is an MXNet multilayer RNNSequence operation together with the data nodes
    connected to its input ports 0 ('input') and 1 ('params').

    NOTE(review): the ``def`` line and the ``dict(nodes=[...], edges=[...])`` wrapper were not
    visible in the reviewed excerpt; they are restored following the usual pattern-replacer
    convention and should be confirmed against the original file.

    :return: dictionary with the 'nodes' and 'edges' lists of the pattern
    """
    return dict(
        nodes=[
            ('rnn_layer', dict(kind='op', type='RNNSequence', format='mxnet', multilayers=True)),
            ('input', dict(kind='data')),
            ('params', dict(kind='data')),
        ],
        edges=[
            ('input', 'rnn_layer', {'in': 0}),
            ('params', 'rnn_layer', {'in': 1}),
        ]
    )
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched multilayer RNNSequence node with a chain of single-layer cells.

    :param graph: graph being transformed
    :param match: matched sub-graph; the 'rnn_layer' entry is the node that gets replaced
    """
    multilayer_node = match['rnn_layer']

    # Build the per-layer cells first, then attach their collected states
    # to the state outputs of the original node.
    per_layer_states = self.split_multilayer_cell(graph, match)
    self.concat_output_states(graph, match, per_layer_states)

    # The original node is removed only after all replacements are wired in.
    multilayer_node.graph.remove_node(multilayer_node.id)
@staticmethod
def get_new_cell(multilayer_cell: Node, number: int):
    """
    Create a single-layer operation of the same class as the given multilayer cell.

    Declared static because callers invoke it as ``self.get_new_cell(node, index)`` while the
    signature has no ``self`` parameter.

    :param multilayer_cell: multilayer RNN sequence node whose attributes are inherited
    :param number: index of the layer the new cell stands for; used only in the cell name
    :return: object produced by the operation class registered under ``multilayer_cell.op``
    """
    cell_class = Op.get_op_class_by_name(multilayer_cell.op)

    # Work on a copy so the attributes of the original node are not modified.
    attrs = multilayer_cell.attrs().copy()
    attrs.update({
        # NOTE(review): these two entries were not visible in the reviewed excerpt. They are
        # restored because a new cell that kept multilayers=True would match the pattern of
        # this replacer again - confirm against the original file.
        'num_layers': 1,
        'multilayers': False,
        'name': multilayer_cell.name + '/LayerSplittedLSTM/{}'.format(number),
    })
    return cell_class(multilayer_cell.graph, attrs)
def split_multilayer_cell(self, graph: Graph, match: dict):
    """
    Split one multilayer type=RNNSequence cell into num_layers consecutive single-layer cells.

    The flat parameters array of the multilayer cell is cut into per-layer pieces, the optional
    initial hidden/cell state values (input ports 2 and 3) are sliced per layer, and the new
    cells are chained so that the data output of each cell is the input of the next one.
    The last cell writes into the data output of the original node.

    NOTE(review): several short lines of this method (state-presence flags, the first argument
    of ``Op._create_data_node``, the ``if`` headers of the visible ``else`` branches, the
    ``inputs=[...]`` opening and the final ``return``) were not visible in the reviewed excerpt
    and are restored from how the surrounding code uses them - confirm against the original.

    :param graph: graph being transformed; new nodes are created on ``rnn_layer.graph``
    :param match: matched sub-graph with the 'input', 'rnn_layer' and 'params' nodes
    :return: list of two lists: output hidden state data nodes of all new cells and
             output cell state data nodes (filled only when the op is 'LSTM')
    """
    input_data = match['input']
    rnn_layer = match['rnn_layer']
    params = match['params'].value.copy()

    # Initial states are optional: they are present only when ports 2 / 3 are connected.
    have_hidden = 2 in rnn_layer.in_nodes()
    hidden_state_value = rnn_layer.in_node(2).value if have_hidden else None

    have_cell = 3 in rnn_layer.in_nodes()
    cell_state_value = rnn_layer.in_node(3).value if have_cell else None

    direction = 2 if rnn_layer.has_num_directions else 1
    num_layers = rnn_layer.num_layers
    input_size = input_data.shape[2]
    bsize = (2 * rnn_layer.hidden_size * direction * num_layers) * rnn_layer.multiplier

    # The first layer consumes input_size features, every following layer consumes the
    # hidden_size * direction outputs of the previous one.
    size = rnn_layer.hidden_size * direction * rnn_layer.multiplier
    first_layer_params_size = (input_size + rnn_layer.hidden_size + 2) * size
    other_layer_params_size = (rnn_layer.hidden_size * direction + rnn_layer.hidden_size + 2) * size
    assert params.size == (first_layer_params_size + (num_layers - 1) * other_layer_params_size), \
        'Unexpected size of parameters for multilayer cell "{}"'.format(rnn_layer.name)

    input_node = input_data
    params_layer_size_count = 0
    output_states = [[], []]

    # The last bsize values of the flat array are split off as param_b, the rest is param_w.
    param_w = params[0:len(params) - bsize]
    param_b = params[len(params) - bsize:]
    layer_bsize = (2 * rnn_layer.hidden_size * direction) * rnn_layer.multiplier

    # Shape of the state outputs does not depend on the layer index, so compute it once.
    state_size = np.array([input_data.shape[rnn_layer.batch_dim], rnn_layer.hidden_size], dtype=np.int64)
    if rnn_layer.has_num_directions:
        state_size = np.insert(state_size, 0, direction)

    is_lstm = rnn_layer.op == 'LSTM'

    for layer in range(num_layers):
        params_layer_size = first_layer_params_size if layer == 0 else other_layer_params_size
        layer_wsize = params_layer_size - layer_bsize

        layer_params_w = param_w[params_layer_size_count: params_layer_size_count + layer_wsize].copy()
        layer_params_b = param_b[layer_bsize * layer: layer_bsize * layer + layer_bsize].copy()
        layer_params = np.concatenate((layer_params_w, layer_params_b), axis=0)
        params_layer_size_count = params_layer_size_count + layer_wsize

        op = self.get_new_cell(rnn_layer, layer)

        params_value_node = Op._create_data_node(
            rnn_layer.graph,
            name=rnn_layer.name + '/LayerSplittedParamsLSTM/{}/'.format(layer),
            attrs={'value': layer_params, 'shape': np.array(layer_params.shape, dtype=np.int64)}
        )

        # Every layer owns `direction` consecutive entries of the initial state values.
        state_slice = slice(layer * direction, layer * direction + direction)

        if have_hidden:
            layer_hidden_state = hidden_state_value[state_slice]
            hidden_state_value_node = Op._create_data_node(
                rnn_layer.graph,
                name=str(rnn_layer.name) + '/LayerSplittedHiddenState/{}/'.format(layer),
                attrs={'value': layer_hidden_state, 'shape': np.array(layer_hidden_state.shape, dtype=np.int64)}
            )
        else:
            hidden_state_value_node = None

        if have_cell:
            layer_cell_state = cell_state_value[state_slice]
            cell_state_value_node = Op._create_data_node(
                rnn_layer.graph,
                name=str(rnn_layer.name) + '/LayerSplittedCellState/{}/'.format(layer),
                attrs={'value': layer_cell_state, 'shape': np.array(layer_cell_state.shape, dtype=np.int64)}
            )
        else:
            cell_state_value_node = None

        if layer < num_layers - 1:
            # Intermediate cells get a fresh data node shaped like the original output.
            output_data = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(0).name + '/LayerSplit/' + str(layer),
                attrs={'shape': rnn_layer.out_node(0).shape.copy()}
            )
        else:
            # The last cell reuses the data output of the original multilayer node.
            output_data = rnn_layer.out_node(0)

        # Output nodes creating:
        output_hidden = Op._create_data_node(
            rnn_layer.graph,
            name=rnn_layer.out_node(1).name + '/LayerSplit/' + str(layer),
            attrs={'shape': np.array(state_size)}
        )

        current_data_nodes = [output_data, output_hidden]

        if is_lstm:
            output_cell = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(2).name + '/LayerSplit/' + str(layer),
                attrs={'shape': np.array(state_size)}
            )
            current_data_nodes.append(output_cell)

        # Input order follows the ports of the original node: data, params, hidden state, cell state.
        data_nodes = op.create_node_with_data(
            inputs=[
                input_node,
                params_value_node,
                hidden_state_value_node,
                cell_state_value_node
            ],
            data_nodes=current_data_nodes,
        )

        # The data output of this cell becomes the input of the next one.
        input_node = data_nodes[0]
        output_states[0].append(data_nodes[1])

        if is_lstm:
            output_states[1].append(data_nodes[2])

    return output_states
@staticmethod
def concat_output_states(graph: Graph, match: dict, new_states: list):
    """
    Concatenate the output states of the single-layer cells into the original state outputs.

    Declared static because callers invoke it as ``self.concat_output_states(graph, match, states)``
    while the signature has no ``self`` parameter.

    :param graph: graph being transformed; the Concat operations are created on ``rnn_layer.graph``
    :param match: matched sub-graph; only the 'rnn_layer' node is used
    :param new_states: two lists of data nodes (hidden states, cell states) collected per layer
    """
    rnn_layer = match['rnn_layer']

    # Output ports 1 and 2 of the multilayer node; None when a port is not connected.
    original_states = [rnn_layer.out_node(port) if port in rnn_layer.out_nodes() else None for port in [1, 2]]

    # NOTE(review): the 'axis' entries and the list brackets were not visible in the reviewed
    # excerpt; 'axis': -1 is restored from the upstream implementation - confirm against the
    # original file.
    concat_ops = [
        Concat(rnn_layer.graph, {
            'name': rnn_layer.name + '/FinalLayerSplitConcat/HiddenState',
            'axis': -1
        }),
        Concat(rnn_layer.graph, {
            'name': rnn_layer.name + '/FinalLayerSplitConcat/CellState',
            'axis': -1
        })
    ]

    for concat_op, original_state, layer_states in zip(concat_ops, original_states, new_states):
        if original_state is None:
            # Nothing consumes this state, so no Concat is needed for it.
            continue
        concat_op.attrs.update({'in_ports_count': len(layer_states)})
        concat_op.create_node_with_data(inputs=layer_states, data_nodes=[original_state])