2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from extensions.middle.FusePermutesSequence import FusePermutesSequence
21 from mo.middle.replacement import MiddleReplacementPattern
24 class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
26 Resolves all differences in TensorFlow LSTMCell and Inference Engine LSTMCell:
28 - shift_const value addition to biases
29 - extra inputs deletion
43 nodes=[('lstm', dict(op='LSTMCell', tf=True))],
def replace_pattern(self, graph: 'nx.MultiDiGraph', match: dict):
    """
    Rewrite the constants and inputs of a matched TensorFlow LSTMCell in place.

    The weights (input 3) and biases (input 4) of the matched node are
    modified as follows:
    - gates are reordered from TF order (i, c, f, o) to (f, i, c, o);
    - ``shift_const`` of the node is added to the f-gate biases;
    - weights are transposed from IO to OI layout.
    Finally all edges coming from the node's extra inputs are removed.

    :param graph: graph being transformed (the edges are removed through ``node.graph``)
    :param match: pattern match; ``match['lstm']`` is the LSTMCell node to update
    :raises AssertionError: if the weights/biases have other consumers or their
        shapes do not agree with the input and hidden sizes
    """
    node = match['lstm']
    weights_node = node.in_node(3)
    biases_node = node.in_node(4)
    shift_const = node.shift_const

    # make sure that the node is the only consumer of weights and biases
    # to let us modify them without hassle
    assert len(weights_node.out_nodes()) == 1, "LSTMCell weights have more than one consumer"
    assert len(biases_node.out_nodes()) == 1, "LSTMCell biases have more than one consumer"

    # TF stores weights in IO order: [input_size + hidden_size, 4 * hidden_size]
    input_size = node.in_node(0).shape[1]
    hidden_size = node.in_node(1).shape[1]
    weights = weights_node.value
    biases = biases_node.value
    assert weights.shape[0] == input_size + hidden_size, "weights.shape={} input_size={} hidden_size={}".format(
        weights.shape, input_size, hidden_size)
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size,\
        "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

    # Assign temporary shapes with an explicit gate axis for easier manipulation
    num_gates = 4
    weights = weights.reshape([weights.shape[0], num_gates, hidden_size])
    biases = biases.reshape([num_gates, hidden_size])

    # Reorder gates icfo --> fico for both weights and biases.
    # np.take returns new arrays, so the original constant values are not touched.
    gate_reorder = [2, 0, 1, 3]
    weights = np.take(weights, gate_reorder, axis=1)
    biases = np.take(biases, gate_reorder, axis=0)

    # shift_const should be added to the first 1/4th part of the biases (f-gate: 0)
    # Note: in case of moving this code up before gate reordering, the addition
    # should be applied at different place
    biases[0] += shift_const

    # Return to the original ranks
    weights = weights.reshape([weights.shape[0], -1])
    biases = biases.flatten()

    # TF stores weights in IO, but IE requires it in OI: transpose
    weights = weights.transpose()

    weights_node.value = weights
    weights_node.shape = np.array(weights.shape, dtype=np.int64)
    biases_node.value = biases
    biases_node.shape = np.array(biases.shape, dtype=np.int64)

    # Cut all extra inputs off: they occupy the ports right after the regular inputs
    first_extra = len(node.inputs)
    for port in range(first_extra, first_extra + len(node.extra_inputs)):
        node.graph.remove_edge(node.in_node(port).id, node.id)