model-optimizer/extensions/middle/GRURNNSequenceToTensorIterator.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
import numpy as np

from extensions.ops.tensor_iterator import TensorIterator
from mo.graph.graph import Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape


class GRUAndRNNToTensorIterator(MiddleReplacementPattern):
    """ Converts a normalized RNNSequence with op=GRU/RNN to TensorIterator.

        A normalized RNNSequence is one that has already been processed by the
        RNNSequenceNormalize transform, which puts the node into a strict,
        predictable form.

        This transformation builds an alternative sub-graph for the sequence:
        a TensorIterator connected in the same way as the original node, with
        an internal body consisting of a GRUCell/RNNCell op node and the
        necessary squeeze/unsqueeze Reshape ops around it.
    """
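    # Resulting sub-graph, schematically (illustrative sketch; the external
    # port numbers match the port maps assigned in replace_pattern below):
    #
    #   X ------(ext port 0)--> TensorIterator --(ext port 3)--> stacked h states
    #   h_init -(ext port 1)-->       |         --(ext port 4)--> last h state (optional)
    #                                 |
    #     body: squeeze -> GRU/RNN Cell -> unsqueeze, with a back edge on h state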

    enabled = True
    id = 'gru_and_rnn_to_tensor_iterator'

    def run_after(self):
        from extensions.middle.RNNSequenceNormalizeToIE import RNNSequenceNormalize
        return [RNNSequenceNormalize]

    def run_before(self):
        from extensions.middle.FusePermutesSequence import FusePermutesSequence
        return [FusePermutesSequence]

    def pattern(self):
        return dict(
            nodes=[
                ('rnn_layer', dict(kind='op', type='RNNSequence')),
                ('input', dict(kind='data')),
                ('weights', dict(kind='data')),
                ('biases', dict(kind='data')),
                # don't capture optional input initial states here
                ('output', dict(kind='data')),
                # don't capture optional output last states here
            ],
            edges=[
                ('input', 'rnn_layer', {'in': 0}),
                ('weights', 'rnn_layer', {'bin': 'weights', 'in': 1}),
                ('biases', 'rnn_layer', {'bin': 'biases', 'in': 2}),
                ('rnn_layer', 'output', {'out': 0}),
            ]
        )

    @staticmethod
    def get_rnn_cell(name: str):
        # Resolve the corresponding cell op class: 'GRU' -> GRUCell, 'RNN' -> RNNCell
        op = Op.get_op_class_by_name(name + 'Cell')
        return op

    def replace_pattern(self, graph: Graph, match: dict):
        # LSTM sequences are handled by a separate transform, so skip them here
        if match['rnn_layer']['op'] == 'LSTM':
            return

        rnn_layer = match['rnn_layer']

        # Build the TensorIterator body first
        body = Graph(name=rnn_layer.name + '/sub_graph')
        body.graph = graph.graph

        # 1. Input squeeze Reshape
        inputs = [Op._create_data_node(body, rnn_layer.name + '/inport/' + str(inp),
                                       {'shape': rnn_layer.in_node(inp).shape.copy(),
                                        'value': rnn_layer.in_node(inp).value.copy()
                                        if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
                  for inp in [0, 4, 1, 2]]  # X, h_init, WR, B

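        # Illustrative example (assumed shapes, not from the original source):
        # for input X of shape [T, B, I] with sequence_dim=0 and batch_dim=1,
        # the per-iteration slice has shape [1, B, I] and reshape_dim below
        # becomes [-1, I], i.e. the sequence axis is squeezed away before the
        # slice is fed to the cell.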
        inputs[0].shape[rnn_layer.sequence_dim] = 1
        reshape_dim = inputs[0].shape.copy()
        reshape_dim[rnn_layer.batch_dim] = -1
        reshape_dim = np.delete(reshape_dim, rnn_layer.sequence_dim)
        input_squeeze = Reshape(
            body,
            dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0, dim=reshape_dim)
        )
        inputs[0] = input_squeeze.create_node_with_data([inputs[0]], edge_attrs=[{'internal_port_id': 0}])

        # 2. Output unsqueeze Reshape
        outputs = [Op._create_data_node(body, rnn_layer.name + '/outport/' + str(out),
                                        {'shape': rnn_layer.out_node(out).shape.copy() if out in rnn_layer.out_nodes() else None})
                   for out in [0]]
        for out in outputs:
            add_opoutput(body, out.id, 0, False)

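        # Illustrative example (assumed shapes, mirroring the input case): for
        # an output of shape [T, B, H], the cell produces [B, H] per iteration;
        # the unsqueeze Reshape below restores a [1, -1, H] slice so that
        # TensorIterator can concatenate the slices along the sequence axis.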
        unsqueezed_output_shape = outputs[0].shape.copy()
        unsqueezed_output_shape[rnn_layer.sequence_dim] = 1
        squeezed_output_shape = np.delete(unsqueezed_output_shape, rnn_layer.sequence_dim)
        outputs[0].shape = squeezed_output_shape
        unsqueezed_output_shape[rnn_layer.batch_dim] = -1
        output_unsqueeze = Reshape(body, dict(name=rnn_layer.name + '/output_unsqueeze', dim=unsqueezed_output_shape,
                                              internal_layer_id=2))

        additional_attrs = dict(activations=rnn_layer.activations,
                                activation_alpha=rnn_layer.activation_alpha,
                                activation_beta=rnn_layer.activation_beta,
                                clip=rnn_layer.clip)
        # linear_before_reset is a GRU-only attribute (it controls whether the
        # hidden state is multiplied by the recurrence weights before or after
        # the reset gate is applied)
        if rnn_layer.op == 'GRU':
            additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

        # 3. GRU/RNN Cell
        rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(body, dict(hidden_size=rnn_layer.hidden_size,
                                                                    name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                                                                    **additional_attrs,
                                                                    internal_layer_id=1))

        gru_cell = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                     edge_attrs=[{}, {'internal_port_id': 1},
                                                                 {'internal_port_id': 2}, {'bin': 'weights'},
                                                                 {'bin': 'biases'}])

        # internal ports for outputs of cell
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

        gru_cell = output_unsqueeze.create_node_with_data([gru_cell])
        gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
        add_opoutput(body, gru_cell.id, 0, False)

        # 4. Create the TensorIterator layer
        assert rnn_layer.direction in ['forward', 'reverse']
        if rnn_layer.direction == 'forward':
            stride = 1
            start = None
            end = None
        else:
            assert rnn_layer.direction == 'reverse'
            stride = -1
            start = -1
            end = 0

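        # Note: axis/stride/start/end describe how TensorIterator slices the
        # sequence axis of the corresponding port: forward iterates the slices
        # from first to last, reverse iterates from the last slice backwards.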
        # stacked h_state
        output_port_map = [{
            'external_port_id': 3,
            'internal_layer_id': 2,
            'internal_port_id': 3,

            'axis': rnn_layer.sequence_dim,
            'stride': stride,
            'start': start,
            'end': end,
            'part_size': 1,
        }]

        # Add the last h_state to outputs
        if len(rnn_layer.out_nodes()) == 2:
            output_port_map.append({
                'external_port_id': 4,
                'internal_layer_id': 1,
                'internal_port_id': 4,
            })

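        # Port mapping summary (for reference; numbers match the maps below):
        #   external 0 -> internal layer 0, port 0: X, sliced along sequence_dim
        #   external 1 -> internal layer 1, port 1: h_init into the cell
        #   external 3 <- internal layer 2, port 3: per-step h states, stacked
        #   external 4 <- internal layer 1, port 4: last h state (optional)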
        ti_op = TensorIterator(graph, {
            'name': rnn_layer.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 4,
            'out_ports_count': len(rnn_layer.out_nodes()),

            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,

                    'axis': rnn_layer.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
            ],

            'output_port_map': output_port_map,
            # the only back edge: h_state flows from the cell output back to its input
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
            ]
        })


        assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
            "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

        outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                           data_nodes=[rnn_layer.out_node(i) for i in range(len(rnn_layer.out_nodes()))],
                                           edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

        if not isinstance(outs, list):
            outs = [outs]

        # Remove the original sequence node and remap the external output ports
        # of the TensorIterator: stacked h states to port 3, last h state to 4
        graph.remove_node(rnn_layer.id)
        outs[0].in_edge(0)['external_port_id'] = 3
        for i, out in enumerate(outs[1:]):
            out.in_edge()['external_port_id'] = 4 + i