Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / TF_lstm_cell_to_generic.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18
19 from extensions.middle.FusePermutesSequence import FusePermutesSequence
20 from mo.graph.graph import Graph
21 from mo.middle.replacement import MiddleReplacementPattern
22
23
24 class TensorFlowLSTMtoGeneric(MiddleReplacementPattern):
25     """
26     Resolves all differences in TensorFlow LSTMCell and Inference Engine LSTMCell:
27     - weights transposing
28     - shift_const value addition to biases
29     - extra inputs deletion
30     """
31     enabled = True
32
33     def run_after(self):
34         from extensions.middle.pass_separator import MiddleStart
35         return [MiddleStart]
36
37     def run_before(self):
38         return [
39             FusePermutesSequence,
40         ]
41
42     def pattern(self):
43         return dict(
44             nodes=[('lstm', dict(op='LSTMCell', tf=True))],
45             edges=[]
46         )
47
48     def replace_pattern(self, graph: Graph, match: dict):
49         weights_node = match['lstm'].in_node(3)
50         biases_node = match['lstm'].in_node(4)
51         node = match['lstm']
52         shift_const = node.shift_const
53
54         # make sure that the node is the only consumer or weights and biases
55         # to let us modify them without hassle
56         assert len(weights_node.out_nodes()) == 1
57         assert len(biases_node.out_nodes()) == 1
58
59         # Assign temporary shape for them for easier manipulation
60         # TF stores weights in IO order
61         input_size = node.in_node(0).shape[1]
62         hidden_size = node.in_node(1).shape[1]
63         weights = weights_node.value
64         biases = biases_node.value
65         assert weights.shape[0] == input_size + hidden_size, \
66             "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
67         assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
68             "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)
69
70         weights = weights.reshape([
71             weights.shape[0],
72             4,  # gates
73             hidden_size
74         ])
75
76         biases = biases.reshape([
77             4,  # gates
78             hidden_size
79         ])
80
81         # Reorder gates icfo --> fico for both weights and biases
82         gate_reorder = [2, 0, 1, 3]
83         weights = np.take(weights, gate_reorder, axis=1)
84         biases = np.take(biases, gate_reorder, axis=0)
85
86         # shift_const.value should be added to the first 1/4th part of the biases (f-gate: 0)
87         # Note: in case of moving this code up before gate reordering, the addition
88         # should be applied at different place
89         biases[0] += shift_const
90
91         # Return to the original shapes
92         weights = weights.reshape([weights.shape[0], -1])
93         biases = biases.flatten()
94
95         # TF stores weights in IO, but IE requires it in OI: transpose
96         weights = weights.transpose()
97
98         weights_node.value = weights
99         weights_node.shape = np.array(weights.shape, dtype=np.int64)
100         biases_node.value = biases
101         biases_node.shape = np.array(biases.shape, dtype=np.int64)
102
103         # Cut all extra inputs off
104         for i in range(len(node.inputs), len(node.inputs) + len(node.extra_inputs)):
105             node.graph.remove_edge(node.in_node(i).id, node.id)