Copyright (c) 2018-2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
import numpy as np

from mo.front.caffe.extractors.utils import embed_input
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.ops.activation import Activation
from mo.ops.clamp import Clamp
from mo.ops.eltwise import Eltwise
from mo.ops.inner_product import InnerProduct
from mo.ops.memory import Memory
from mo.ops.scale_shift import ScaleShiftOp
from mo.ops.split import Split
def unique_id(prefix: str = 'id') -> str:
    """
    Return a string id that has not been handed out before.

    The id is the bare prefix when it is still free; otherwise the prefix is
    suffixed with '_<index>' and the index advanced until an unused name is
    found.  Every returned name is recorded in the function attribute
    ``unique_id.names`` so later calls never repeat it.

    :param prefix: optional string stem for the generated id.
    :return: a name unique across all previous calls in this process.
    """
    index = len(unique_id.names)
    name = prefix  # first candidate is the bare prefix
    while name in unique_id.names:
        # candidate already taken -- try the next numeric suffix
        name = '{}_{}'.format(prefix, index)
        index += 1
    unique_id.names.append(name)
    return name


# Registry of every name handed out so far (module-level state shared by all callers).
unique_id.names = []
# NOTE(review): this chunk looks like a lossy extraction -- the original file's
# own line numbers are fused into every line, indentation is lost, and many
# lines (dict closing braces, several attribute entries, the class 'op' field,
# the run_before()/run_after() method headers) are missing.  The code below is
# therefore left byte-identical; comments only describe the visible logic.
#
# Purpose (from the visible code): replace one matched LSTM node with an
# explicit subgraph of primitive ops -- InnerProduct (fully-connected) layers,
# Eltwise sums/muls, a 4-way Split, sigmoid/tanh Activations, a Clamp, and two
# Memory read/write pairs that carry the recurrent output and the cell state
# across invocations (Kaldi-style stateful execution, judging by the
# gifo-prefixed weight attributes -- confirm against the full source).
47 class ReplaceLSTMNodePattern(FrontReplacementOp):
51 # we need to rewrite this transform to fit unified pipeline (it should be a part of traditional FRONT phase)
# NOTE(review): this import presumably sits inside a run_before()/run_after()
# ordering hook whose 'def' line is missing from this extraction -- verify.
53 from extensions.front.output_cut import OutputCut
# Entry point: builds the LSTM subgraph and returns the id(s) of the node(s)
# that take over the matched node's outputs (FrontReplacementOp contract).
59 def replace_op(self, graph: Graph, node: Node):
60 input_node = node.in_node()
# Two unique ids pair the Memory read/write halves:
#   memory_pair_input  links prev_memory_output <-> next_lstm_output (recurrent output)
#   memory_pair_output links prev_memory_state  <-> next_lstm_state  (cell state)
62 memory_pair_input = unique_id('id')
63 memory_pair_output = unique_id('id')
# x_t -> FC with the stacked g/i/f/o input weights and biases.
65 # Input -> FullyConnected
66 fc_layer_after_input_attrs = {'name': 'input_fullyconnected',
67 'num_output': node.gifo_x_weights_shape[0],
# embed_input attaches the constant weight/bias blobs as inputs 1 and 2.
71 embed_input(fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
72 embed_input(fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
73 fc_layer_after_input = InnerProduct(graph, fc_layer_after_input_attrs).create_node([input_node])
# Read side of the recurrent-output Memory pair (previous step's output).
75 prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
76 'id': memory_pair_input,
79 'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
# Previous output -> FC with the stacked g/i/f/o recurrent weights (no bias here).
82 # *Memory(output) -> FullyConnected
83 fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
84 'num_output': node.gifo_r_weights_shape[0],
88 embed_input(fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
89 fc_layer_from_prev_state = InnerProduct(graph, fc_layer_from_prev_state_attrs).create_node(
# Sum of input and recurrent projections is the joint gate pre-activation.
92 # Memory -> FullyConnected \
94 # Input -> FullyConnected /
95 join_input_prev_state_sum = Eltwise(graph, {'name': 'join_input_eltwise',
97 }).create_node([fc_layer_from_prev_state,
98 fc_layer_after_input])
100 # *Eltwise(sum) -> Split
101 # it is split into 4 nodes: Act, Eltw*3
102 # the following order is mandatory
105 # Split ---(2)Eltwise(sum)
107 # | \__(3)Eltwise(sum)
108 # |____(4)Eltwise(sum)
# Port order relied upon below: 0 -> candidate (tanh), 1 -> input gate,
# 2 -> forget gate, 3 -> output gate.
109 split_joined_input = Split(graph, {'name': 'join_input_split',
112 'out_ports_count': 4,
113 }).create_node([join_input_prev_state_sum])
# Read side of the cell-state Memory pair (previous step's cell state).
115 prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
116 'id': memory_pair_output,
119 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
# Per-element scaling of the previous state feeding the input/forget gates --
# looks like the diagonal (peephole-style) connections; confirm semantics.
122 # *Memory(state) -> *ScaleShift(input)
123 state_input_scaleshift_attrs = {'name': 'input_scaleshift',
126 embed_input(state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
127 state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
129 # *Memory(state) -> *ScaleShift(forget)
130 state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
133 embed_input(state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
134 state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
# Add the state contributions to Split port 1 (input gate) and port 2
# (forget gate) before the sigmoids.
138 # Memory(state) -> *ScaleShift(input) /
139 join_prev_lstm_input_joined_input_sum = Eltwise(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
141 }).create_node([(split_joined_input, 1),
142 state_input_scaleshift
146 # Memory(state) -> *ScaleShift(forget) /
147 join_prev_lstm_input_joined_forget_sum = Eltwise(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
149 }).create_node([(split_joined_input, 2),
150 state_forget_scaleshift
# Candidate activation from Split port 0 ('tahn' spelling is historical).
154 remember_tahn = Activation(graph, {'name': 'remember_tahnv',
156 }).create_node([(split_joined_input, 0)])
158 # Split -> (2)Eltwise(sum) -> *Sigmoid
159 remember_sigmoid = Activation(graph, {'name': 'remember_sigmoid',
160 'operation': 'sigmoid'
162 [join_prev_lstm_input_joined_input_sum])
164 # Split -> (3)Eltwise(sum) -> **Sigmoid
165 forget_sigmoid = Activation(graph, {'name': 'forget_sigmoid',
166 'operation': 'sigmoid'
168 [join_prev_lstm_input_joined_forget_sum])
# New cell state: forget_gate * prev_state + input_gate * candidate.
172 # Split -> (3)Eltwise(sum) -> **Sigmoid /
173 join_forget_prev_state_mul = Eltwise(graph, {'name': 'join_forget_prev_state_mul',
176 [forget_sigmoid, prev_lstm_state])
180 # Split -> (2)Eltwise(sum) -> *Sigmoid /
181 join_remember_candidates_mul = Eltwise(graph, {'name': 'join_remember_candidates_mul',
184 [remember_tahn, remember_sigmoid])
189 join_forget_remember_sum = Eltwise(graph, {'name': 'join_forget_remember_sum',
192 [join_forget_prev_state_mul, join_remember_candidates_mul])
# Symmetric clipping of the new cell state to [-clip_value, clip_value].
194 # (7)Eltwise(sum) -> Clamp
195 join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
196 'max': node.clip_value,
197 'min': -node.clip_value
199 [join_forget_remember_sum])
# Write side of the cell-state Memory pair: store the clipped state for the
# next invocation.
201 # Clamp -> (2)Memory(state)
202 Memory(graph, {'name': 'next_lstm_state',
203 'id': memory_pair_output,
206 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
207 }).create_node([join_forget_clamp])
# tanh of the new cell state for the output path.
210 state_filtered_tahn = Activation(graph, {'name': 'state_filtered_tahn',
212 }).create_node([join_forget_clamp])
# Output-gate state contribution (ScaleShift with output_gate_weights).
214 # Clamp -> (2)ScaleShift
215 clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
217 embed_input(clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
218 clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
# Output-gate pre-activation: Split port 3 plus the state contribution.
222 # Clamp -> (2)ScaleShift /
223 join_next_lstm_input_joined_input_sum = Eltwise(graph, {'name': 'join_next_lstm_input_joined_input_sum',
225 }).create_node([(split_joined_input, 3), clamp_scaleshift])
227 # (4)Eltwise(sum) -> (3)Sigmoid
228 output_sigmoid = Activation(graph, {'name': 'output_sigmoid',
229 'operation': 'sigmoid'
231 [join_next_lstm_input_joined_input_sum])
# Recurrent output path: output_gate * tanh(state).
233 # (4)Eltwise(sum) -> (3)Sigmoid \
236 joined_output_mul = Eltwise(graph, {'name': 'joined_output_mul',
238 }).create_node([state_filtered_tahn, output_sigmoid])
# Final projection FC -- the node that replaces the matched LSTM's output.
240 # (5)Eltwise(mul) -> (3)FullyConnected
241 fc_output_attrs = {'name': 'FullyConnected',
242 'num_output': node.projection_weights_shape[0],
244 embed_input(fc_output_attrs, 1, 'weights', node.projection_weights)
245 fc_output = InnerProduct(graph, fc_output_attrs).create_node([joined_output_mul])
# Write side of the recurrent-output Memory pair: store the projected output
# for the next invocation; the ordinary data edge to downstream consumers is
# re-created automatically by the replacement framework (see comment below).
247 # / (2)Memory(output)
249 # \ Output (any next node) (edge created automatically after replacement)
250 Memory(graph, {'name': 'next_lstm_output',
251 'id': memory_pair_input,
254 'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
255 }).create_node([fc_output])
# FrontReplacementOp contract: return the id(s) of the replacement node(s).
257 return [fc_output.id]