Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / kaldi / replace_lstm_node_pattern.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import numpy as np
17
18 from mo.front.caffe.extractors.utils import embed_input
19 from mo.front.common.replacement import FrontReplacementOp
20 from mo.graph.graph import Node, Graph
21 from mo.ops.activation import Activation
22 from mo.ops.clamp import Clamp
23 from mo.ops.eltwise import Eltwise
24 from mo.ops.inner_product import InnerProduct
25 from mo.ops.memory import Memory
26 from mo.ops.scale_shift import ScaleShiftOp
27 from mo.ops.split import Split
28
29
def unique_id(prefix: str = 'id') -> str:
    """
    Generate a name that has not been handed out before.

    The optional *prefix* is tried first as-is; if it is already taken,
    numeric suffixes are appended (starting from the current number of
    issued names) until a free candidate is found. Every returned name
    is recorded on ``unique_id.names`` so later calls never repeat it.
    """
    taken = unique_id.names
    candidate, suffix = prefix, len(taken)
    while candidate in taken:
        candidate = '{}_{}'.format(prefix, suffix)
        suffix += 1
    taken.append(candidate)
    return candidate


# Registry of every name issued so far; kept on the function object so the
# state persists across calls without a module-level global.
unique_id.names = []
45
46
class ReplaceLSTMNodePattern(FrontReplacementOp):
    """
    Expand a Kaldi ``LSTMCell`` node into an explicit sub-graph of primitive
    operations (InnerProduct, Eltwise, Split, Activation, Clamp, ScaleShift)
    with the recurrent state carried through two Memory read/write pairs:
    one pair for the cell output (h) and one for the cell state (c).
    """
    op = "LSTMCell"
    enabled = True

    # we need to rewrite this transform to fit unified pipeline (it should be a part of traditional FRONT phase)
    def run_before(self):
        from extensions.front.output_cut import OutputCut
        return [OutputCut]

    def run_after(self):
        return []

    def replace_op(self, graph: Graph, node: Node):
        """
        Build the LSTM cell sub-graph in place of *node* and return the id of
        the node that produces the cell output (the projection InnerProduct),
        so the replacement framework can rewire consumers to it.
        """
        input_node = node.in_node()

        # Each Memory pair shares one id: the node created with index=1 reads
        # the value stored on the previous time step, and the node with
        # index=0 (created near the end of this method) writes the value for
        # the next step.
        memory_pair_input = unique_id('id')
        memory_pair_output = unique_id('id')

        # Input -> FullyConnected
        # Joint input projection for all four gates (g, i, f, o); weights and
        # biases come packed from the Kaldi extractor.
        fc_layer_after_input_attrs = {'name': 'input_fullyconnected',
                                      'num_output': node.gifo_x_weights_shape[0],
                                      'bias_term': True
                                      }

        embed_input(fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
        embed_input(fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
        fc_layer_after_input = InnerProduct(graph, fc_layer_after_input_attrs).create_node([input_node])

        # Reads the previous time step's cell output (h_{t-1}).
        prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
                                          'id': memory_pair_input,
                                          'index': 1,
                                          'size': 2,
                                          'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
                                          }).create_node()

        # *Memory(output) -> FullyConnected
        # Joint recurrent projection of h_{t-1} for all four gates.
        fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
                                          'num_output': node.gifo_r_weights_shape[0],
                                          'bias_term': False
                                          }

        embed_input(fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
        fc_layer_from_prev_state = InnerProduct(graph, fc_layer_from_prev_state_attrs).create_node(
            [prev_lstm_output])

        # Memory -> FullyConnected  \
        #                           *Eltwise(sum)
        # Input -> FullyConnected   /
        join_input_prev_state_sum = Eltwise(graph, {'name': 'join_input_eltwise',
                                                    'operation': 'sum'
                                                    }).create_node([fc_layer_from_prev_state,
                                                                    fc_layer_after_input])

        # *Eltwise(sum) -> Split
        # it is split into 4 nodes: Act, Eltw*3
        # the following order is mandatory
        #       ___Tanh
        #      /
        # Split ---(2)Eltwise(sum)
        #     |\
        #     | \__(3)Eltwise(sum)
        #     |____(4)Eltwise(sum)
        split_joined_input = Split(graph, {'name': 'join_input_split',
                                           'axis': 1,
                                           'num_split': 4,
                                           'out_ports_count': 4,
                                           }).create_node([join_input_prev_state_sum])

        # Reads the previous time step's cell state (c_{t-1}).
        prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
                                         'id': memory_pair_output,
                                         'index': 1,
                                         'size': 2,
                                         'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
                                         }).create_node()

        # *Memory(state) -> *ScaleShift(input)
        # Peephole connection from c_{t-1} into the input gate.
        state_input_scaleshift_attrs = {'name': 'input_scaleshift',
                                        'bias_term': False
                                        }
        embed_input(state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
        state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])

        # *Memory(state) -> *ScaleShift(forget)
        # Peephole connection from c_{t-1} into the forget gate.
        state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
                                         'bias_term': False
                                         }
        embed_input(state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
        state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])

        # Split                                 \
        #                                       (2)Eltwise(sum)
        # Memory(state) -> *ScaleShift(input)  /
        join_prev_lstm_input_joined_input_sum = Eltwise(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
                                                                'operation': 'sum'
                                                                }).create_node([(split_joined_input, 1),
                                                                                state_input_scaleshift
                                                                                ])
        # Split                                 \
        #                                       (3)Eltwise(sum)
        # Memory(state) -> *ScaleShift(forget)  /
        join_prev_lstm_input_joined_forget_sum = Eltwise(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
                                                                 'operation': 'sum'
                                                                 }).create_node([(split_joined_input, 2),
                                                                                 state_forget_scaleshift
                                                                                 ])

        # Split -> Tanh
        # Candidate cell values (g gate). NOTE(review): "tahn" in names is a
        # long-standing typo kept as-is — the node name string is part of the
        # produced IR and must not be silently changed here.
        remember_tahn = Activation(graph, {'name': 'remember_tahnv',
                                           'operation': 'tanh'
                                           }).create_node([(split_joined_input, 0)])

        # Split -> (2)Eltwise(sum) -> *Sigmoid
        # Input gate activation.
        remember_sigmoid = Activation(graph, {'name': 'remember_sigmoid',
                                              'operation': 'sigmoid'
                                              }).create_node(
            [join_prev_lstm_input_joined_input_sum])

        # Split -> (3)Eltwise(sum) -> **Sigmoid
        # Forget gate activation.
        forget_sigmoid = Activation(graph, {'name': 'forget_sigmoid',
                                            'operation': 'sigmoid'
                                            }).create_node(
            [join_prev_lstm_input_joined_forget_sum])

        # *Memory(state)                        \
        #                                       (6)Eltwise(mul)
        # Split -> (3)Eltwise(sum) -> **Sigmoid /
        join_forget_prev_state_mul = Eltwise(graph, {'name': 'join_forget_prev_state_mul',
                                                     'operation': 'mul'
                                                     }).create_node(
            [forget_sigmoid, prev_lstm_state])

        # Split -> Tahn                         \
        #                                       (5)Eltwise(mul)
        # Split -> (2)Eltwise(sum) -> *Sigmoid   /
        join_remember_candidates_mul = Eltwise(graph, {'name': 'join_remember_candidates_mul',
                                                       'operation': 'mul'
                                                       }).create_node(
            [remember_tahn, remember_sigmoid])

        # (5)Eltwise(mul)  \
        #               (7)Eltwise(sum)
        # (6)Eltwise(mul)   /
        # New cell state: c_t = f*c_{t-1} + i*g (before clipping).
        join_forget_remember_sum = Eltwise(graph, {'name': 'join_forget_remember_sum',
                                                   'operation': 'sum'
                                                   }).create_node(
            [join_forget_prev_state_mul, join_remember_candidates_mul])

        # (7)Eltwise(sum) -> Clamp
        # Kaldi clips the cell state to [-clip_value, clip_value].
        join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
                                          'max': node.clip_value,
                                          'min': -node.clip_value
                                          }).create_node(
            [join_forget_remember_sum])
        #
        # Clamp -> (2)Memory(state)
        # Writes c_t back for the next time step (index=0 side of the pair).
        Memory(graph, {'name': 'next_lstm_state',
                       'id': memory_pair_output,
                       'index': 0,
                       'size': 2,
                       'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
                       }).create_node([join_forget_clamp])

        # Clamp -> (2)Tahn
        state_filtered_tahn = Activation(graph, {'name': 'state_filtered_tahn',
                                                 'operation': 'tanh'
                                                 }).create_node([join_forget_clamp])

        # Clamp -> (2)ScaleShift
        # Peephole connection from c_t into the output gate.
        clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
                                  'bias_term': False}
        embed_input(clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
        clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])

        # Split                 \
        #                       (4)Eltwise(sum)
        # Clamp -> (2)ScaleShift /
        join_next_lstm_input_joined_input_sum = Eltwise(graph, {'name': 'join_next_lstm_input_joined_input_sum',
                                                                'operation': 'sum'
                                                                }).create_node([(split_joined_input, 3), clamp_scaleshift])

        # (4)Eltwise(sum) -> (3)Sigmoid
        # Output gate activation.
        output_sigmoid = Activation(graph, {'name': 'output_sigmoid',
                                            'operation': 'sigmoid'
                                            }).create_node(
            [join_next_lstm_input_joined_input_sum])

        # (4)Eltwise(sum) -> (3)Sigmoid         \
        #                                       (5)Eltwise(mul)
        # Clamp -> (2)Tahn                      /
        # Cell output before projection: h'_t = o * tanh(c_t).
        joined_output_mul = Eltwise(graph, {'name': 'joined_output_mul',
                                            'operation': 'mul'
                                            }).create_node([state_filtered_tahn, output_sigmoid])

        # (5)Eltwise(mul) -> (3)FullyConnected
        # Projection (dimensionality reduction) of the cell output.
        fc_output_attrs = {'name': 'FullyConnected',
                           'num_output': node.projection_weights_shape[0],
                           'bias_term': False}
        embed_input(fc_output_attrs, 1, 'weights', node.projection_weights)
        fc_output = InnerProduct(graph, fc_output_attrs).create_node([joined_output_mul])

        #                   / (2)Memory(output)
        # (3)FullyConnected
        #                   \ Output (any next node) (edge created automatically after replacement)
        # Writes h_t back for the next time step (index=0 side of the pair).
        Memory(graph, {'name': 'next_lstm_output',
                       'id': memory_pair_input,
                       'index': 0,
                       'size': 2,
                       'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
                       }).create_node([fc_output])

        return [fc_output.id]