2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from extensions.middle.FusePermutesSequence import FusePermutesSequence
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
24 class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
26 Resolves all differences in TensorFlow LSTMCell and Inference Engine LSTMCell:
28 - shift_const value addition to biases
29 - extra inputs deletion
34 from extensions.middle.pass_separator import MiddleStart
44 nodes=[('lstm', dict(op='LSTMCell', tf=True))],
def replace_pattern(self, graph: Graph, match: dict):
    """
    Rewrite the weights and biases of a matched TensorFlow LSTMCell in place.

    The following is done for the node found under ``match['lstm']``:
    - gates are reordered from icfo to fico in both weights and biases
    - ``shift_const`` of the node is added to the f-gate part of the biases
    - weights are transposed from IO to OI order
    - every input edge past the regular ``inputs`` (the ``extra_inputs``) is removed

    :param graph: graph being transformed (the edges are removed through ``node.graph``)
    :param match: dictionary with the matched nodes; only the 'lstm' key is read
    :raises AssertionError: if weights/biases have other consumers or unexpected shapes
    """
    node = match['lstm']
    weights_node = node.in_node(3)
    biases_node = node.in_node(4)
    shift_const = node.shift_const

    # make sure that the node is the only consumer of weights and biases
    # to let us modify them without hassle
    assert len(weights_node.out_nodes()) == 1
    assert len(biases_node.out_nodes()) == 1

    # Assign temporary shape for them for easier manipulation
    # TF stores weights in IO order
    input_size = node.in_node(0).shape[1]
    hidden_size = node.in_node(1).shape[1]
    weights = weights_node.value
    biases = biases_node.value
    assert weights.shape[0] == input_size + hidden_size, \
        "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
        "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

    # Split the 4 * hidden_size axis into (gate, hidden) so that gates can be
    # permuted as whole slices.
    # NOTE(review): the dimensions follow from the asserts above and from the
    # axes used by np.take below - confirm against the upstream file.
    weights = weights.reshape([
        weights.shape[0],
        4,  # gates
        hidden_size
    ])

    biases = biases.reshape([
        4,  # gates
        hidden_size
    ])

    # Reorder gates icfo --> fico for both weights and biases
    # np.take returns new arrays, so the in-place addition below does not
    # touch the arrays originally stored in the nodes
    gate_reorder = [2, 0, 1, 3]
    weights = np.take(weights, gate_reorder, axis=1)
    biases = np.take(biases, gate_reorder, axis=0)

    # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
    # Note: in case of moving this code up before gate reordering, the addition
    # should be applied at different place
    biases[0] += shift_const

    # Return to the original shapes
    weights = weights.reshape([weights.shape[0], -1])
    biases = biases.flatten()

    # TF stores weights in IO, but IE requires it in OI: transpose
    weights = weights.transpose()

    weights_node.value = weights
    weights_node.shape = np.array(weights.shape, dtype=np.int64)
    biases_node.value = biases
    biases_node.shape = np.array(biases.shape, dtype=np.int64)

    # Cut all extra inputs off: they occupy the input ports right after the regular ones
    first_extra = len(node.inputs)
    for i in range(first_extra, first_extra + len(node.extra_inputs)):
        node.graph.remove_edge(node.in_node(i).id, node.id)