"""
 Copyright (c) 2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import numpy as np

from extensions.ops.tensor_iterator import TensorIterator
from mo.graph.graph import Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
class GRUAndRNNToTensorIterator(MiddleReplacementPattern):
    """ Converts normalized RNNSequence with op=GRU/RNN to TensorIterator.

        Normalized RNNSequence means that it should be processed by
        RNNSequenceNormalize transform that ensures its strict form.

        This transformation builds an alternative sub-graph for GRUSequence
        with TensorIterator connected in the same way as an original GRUSequence
        node and with internal body represented as GRUCell op node with necessary
        squeezes and unsqueezes around.
    """
    enabled = True
    id = 'gru_and_rnn_to_tensor_iterator'

    def run_after(self):
        # The sequence node must already be in the strict normalized form.
        from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
        return [RNNSequenceNormalize]

    def run_before(self):
        from extensions.middle.FusePermutesSequence import FusePermutesSequence
        return [FusePermutesSequence]

    def pattern(self):
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence')),
                ('input', dict(kind='data')),
                ('weights', dict(kind='data')),
                ('biases', dict(kind='data')),
                # don't capture optional input initial states here
                ('output', dict(kind='data')),
                # don't capture optional output last states here
            ],
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('weights', 'rnn_layer', {'bin': 'weights', 'in': 1}),
                ('biases', 'rnn_layer', {'bin': 'biases', 'in': 2}),
                ('rnn_layer', 'output', {'out': 0}),
            ]
        )

    @staticmethod
    def get_rnn_cell(name: str):
        """Return the single-step cell op class (e.g. GRUCell/RNNCell) for a sequence op name."""
        op = Op.get_op_class_by_name(name + 'Cell')
        return op

    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched GRU/RNN RNNSequence node with a TensorIterator
        whose body is <op>Cell wrapped in squeeze/unsqueeze Reshapes.

        NOTE(review): this body was reconstructed from a partially corrupted
        source dump (several original lines were missing); verify the restored
        lines — in particular stride/start/end, 'in_ports_count' and
        'back_edges' — against the upstream transform.
        """
        # LSTM sequences are converted by a dedicated transformation.
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        # Body input data nodes in order X, h_init, WR, B (sequence in-ports 0, 4, 1, 2);
        # only weights/biases (ports 1, 2) keep their constant values.
        inputs = [Op._create_data_node(body, rnn_layer.name + '/inport/' + str(inp),
                                       {'shape': rnn_layer.in_node(inp).shape.copy(),
                                        'value': rnn_layer.in_node(inp).value.copy()
                                        if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 1, 2]]  # X, h_init, WR, B

        # The body consumes one sequence step per iteration: sequence dim becomes 1
        # and is squeezed away before the cell; batch dim stays dynamic via -1.
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[rnn_layer.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, rnn_layer.sequence_dim)
        input_squeeze = Reshape(
            body,
            dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
        )
        inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        outputs = [Op._create_data_node(body, rnn_layer.name + '/outport/' + str(out),
                                        {'shape': rnn_layer.out_node(out).shape.copy()
                                         if out in rnn_layer.out_nodes() else None})
                   for out in [0]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

        # The cell emits a squeezed (per-step) tensor; the Reshape after the cell
        # restores the sequence dimension of size 1.
        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[rnn_layer.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape, rnn_layer.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[rnn_layer.batch_dim] = -1
        output_unsqueeze = Reshape(body, dict(name=rnn_layer.name + '/output_unsqueeze/', dim=unsqueezed_output_shape,
                                              internal_layer_id=2))

        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        if rnn_layer.op == 'GRU':
            additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. GRUCell/RNNCell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(body, dict(hidden_size=rnn_layer.hidden_size,
                                                                    name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                                                                    **additional_attrs,
                                                                    internal_layer_id=1))

        gru_cell = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                     edge_attrs=[{}, {'internal_port_id': 1},
                                                                 {'internal_port_id': 2}, {'bin': 'weights'},
                                                                 {'bin': 'biases'}])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        gru_cell = output_unsqueeze.create_node_with_data([gru_cell])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. TensorIterator layer creating
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

        # Full sequence of hidden states goes out through external port 3.
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,

            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Adding last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.extend([{
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            }])

        ti_op = TensorIterator(graph, {
            'name': rnn_layer.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 2,
            'out_ports_count': len(rnn_layer.out_nodes()),

            'input_port_map': [
                # X is sliced along the sequence dimension, one step per iteration.
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': rnn_layer.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                # Initial hidden state feeds the cell directly.
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
            ],

            'output_port_map': output_port_map,

            # Carry h_state between iterations (the only state of GRU/RNN cell).
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
            ]
        })

        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                           data_nodes=[rnn_layer.out_node(i) for i in range(len(rnn_layer.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

        if not isinstance(outs, list):
            outs = list([outs])

        graph.remove_node(rnn_layer.id)
        # Re-number external output ports to match the TI port map above.
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            external_port_id = 4 + i
            out.in_edge()['external_port_id'] = external_port_id